1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
|
import sys
import threading
import time
import unittest
from _mssql import MSSQLDatabaseException
from .helpers import mssqlconn, StoredProc, mark_slow
# Stored procedure whose body selects an unknown column from an unknown
# table; it is intended to fail when executed by the threaded sproc test.
error_sproc = StoredProc(
    "pymssqlErrorThreadTest",
    body="SELECT unknown_column FROM unknown_table",
    args=(),
)
class _TestingThread(threading.Thread):
def __init__(self):
super(_TestingThread, self).__init__()
self.results = []
self.exc = None
def run(self):
try:
with mssqlconn() as mssql:
for i in range(0, 1000):
num = mssql.execute_scalar('SELECT %d', (i,))
assert num == i
self.results.append(num)
except Exception as exc:
self.exc = exc
class _TestingErrorThread(_TestingThread):
    """Worker thread that issues a single query naming an unknown column.

    The query is meant to fail. Whatever exception is raised while
    connecting or querying is kept in ``exc`` for the test to check.
    """

    def run(self):
        try:
            with mssqlconn() as conn:
                conn.execute_query('SELECT unknown_column')
        except Exception as err:
            self.exc = err
class _SprocTestingErrorThread(_TestingThread):
    """Worker thread that executes the module-level ``error_sproc`` once.

    Any exception raised while connecting or executing the stored procedure
    is stored in ``exc`` rather than propagated.
    """

    def run(self):
        try:
            with mssqlconn() as conn:
                error_sproc.execute(mssql=conn)
        except Exception as err:
            self.exc = err
class ThreadedTests(unittest.TestCase):
    """Run database work concurrently from several worker threads."""

    def run_threads(self, num, thread_class, poll_interval=5):
        """Start ``num`` threads of ``thread_class`` and wait for all of them.

        A ``.`` is written to stdout on every polling pass as a progress
        indicator. A trailing space is written once every thread has finished.

        :param num: number of threads to create and start.
        :param thread_class: zero-argument callable returning a thread object
            that exposes ``results`` and ``exc`` attributes.
        :param poll_interval: seconds to sleep between polling passes.
        :returns: ``(results, exceptions)``. ``results`` holds the non-empty
            ``results`` lists and ``exceptions`` holds the non-``None`` ``exc``
            values of the threads, in the order the threads were observed to
            finish.
        """
        threads = [thread_class() for _ in range(num)]
        for thread in threads:
            thread.start()
        results = []
        exceptions = []
        pending = threads
        while pending:
            sys.stdout.write(".")
            sys.stdout.flush()
            # Collect survivors into a new list instead of removing finished
            # threads from the list being iterated, which skips elements.
            still_running = []
            for thread in pending:
                if thread.is_alive():
                    still_running.append(thread)
                    continue
                if thread.results:
                    results.append(thread.results)
                if thread.exc is not None:
                    exceptions.append(thread.exc)
            pending = still_running
            # Only wait when something is still running, so there is no
            # pointless sleep after the last thread has finished.
            if pending:
                time.sleep(poll_interval)
        sys.stdout.write(" ")
        sys.stdout.flush()
        return results, exceptions

    @mark_slow
    def testThreadedUse(self):
        num_threads = 50
        results, exceptions = self.run_threads(
            num=num_threads,
            thread_class=_TestingThread)
        # Compare the list itself so a failure message shows the exceptions.
        self.assertEqual(exceptions, [])
        # Every cleanly finished thread must have contributed its results.
        self.assertEqual(len(results), num_threads)
        expected = list(range(0, 1000))
        for result in results:
            self.assertEqual(result, expected)

    @mark_slow
    def testErrorThreadedUse(self):
        _, exceptions = self.run_threads(
            num=2,
            thread_class=_TestingErrorThread)
        self.assertEqual(len(exceptions), 2)
        for exc in exceptions:
            # Exact type check (not isinstance), as in the original test.
            self.assertIs(type(exc), MSSQLDatabaseException)

    @mark_slow
    def testErrorSprocThreadedUse(self):
        # The stored procedure only needs to exist while the threads run.
        with error_sproc.create():
            _, exceptions = self.run_threads(
                num=5,
                thread_class=_SprocTestingErrorThread)
        self.assertEqual(len(exceptions), 5)
        for exc in exceptions:
            # Exact type check (not isinstance), as in the original test.
            self.assertIs(type(exc), MSSQLDatabaseException)
# Module-level suite, presumably picked up by an external test runner.
# TestLoader.loadTestsFromTestCase replaces unittest.makeSuite, which was
# deprecated in Python 3.11 and removed in 3.13 (AttributeError at import).
suite = unittest.TestSuite()
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(ThreadedTests))

if __name__ == '__main__':
    unittest.main()
|