File: test_threaded.py

package info (click to toggle)
pymssql 2.1.4+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 952 kB
  • sloc: python: 2,872; sh: 240; makefile: 148; ansic: 7
file content (111 lines) | stat: -rw-r--r-- 3,066 bytes | parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import sys
import threading
import time
import unittest

from _mssql import MSSQLDatabaseException

from .helpers import mssqlconn, StoredProc, mark_slow


# Stored procedure whose body selects a column from a table that does not
# exist; the threaded tests below expect every execution of it to fail.
error_sproc = StoredProc(
    "pymssqlErrorThreadTest",
    body="SELECT unknown_column FROM unknown_table",
    args=(),
)


class _TestingThread(threading.Thread):
    """Worker thread that runs 1000 scalar ``SELECT`` queries on one connection.

    The outcome is stored on the instance rather than raised, so the thread
    that started this one can inspect ``results`` and ``exc`` after joining.
    """

    def __init__(self):
        super(_TestingThread, self).__init__()
        # Values returned by execute_scalar, appended in query order.
        self.results = []
        # Exception raised inside run(), or None if run() completed cleanly.
        self.exc = None

    def run(self):
        """Query ``SELECT i`` for i in 0..999 and record each returned value.

        Any exception (connection failure, query error, or a value mismatch)
        ends the loop and is stored in ``self.exc`` instead of propagating.
        """
        try:
            with mssqlconn() as mssql:
                for i in range(0, 1000):
                    num = mssql.execute_scalar('SELECT %d', (i,))
                    # Explicit check instead of ``assert`` so the comparison
                    # is not stripped when Python runs with -O.
                    if num != i:
                        raise AssertionError(
                            'expected %d from server, got %r' % (i, num))
                    self.results.append(num)
        except Exception as exc:
            self.exc = exc


class _TestingErrorThread(_TestingThread):
    """Worker thread that issues a query referencing an unknown column.

    The query is expected to fail; the raised exception is kept in
    ``self.exc`` for the test to inspect.
    """

    def run(self):
        try:
            with mssqlconn() as conn:
                conn.execute_query('SELECT unknown_column')
        except Exception as err:
            self.exc = err


class _SprocTestingErrorThread(_TestingThread):
    """Worker thread that executes ``error_sproc`` on its own connection.

    The procedure's body references an unknown table, so execution is
    expected to fail; the raised exception is kept in ``self.exc``.
    """

    def run(self):
        try:
            with mssqlconn() as conn:
                error_sproc.execute(mssql=conn)
        except Exception as err:
            self.exc = err


class ThreadedTests(unittest.TestCase):
    """Run many concurrent connections from worker threads and check outcomes."""

    def run_threads(self, num, thread_class, poll_interval=5):
        """Start ``num`` instances of ``thread_class`` and wait for all of them.

        A progress dot is written to stdout at most every ``poll_interval``
        seconds while threads are still running.

        :param num: number of worker threads to start.
        :param thread_class: thread class whose instances expose ``results``
            and ``exc`` attributes after running.
        :param poll_interval: maximum seconds to wait between progress dots.
        :returns: ``(results, exceptions)``. Each thread contributes at most
            one entry to each list, in start order: its ``results`` list if
            non-empty, and its ``exc`` if it is not None.
        """
        threads = [thread_class() for _ in range(num)]
        for thread in threads:
            thread.start()

        # Rebuild the pending list on each pass. Calling remove() while
        # iterating the same list would skip the thread after each removal.
        pending = list(threads)
        while pending:
            sys.stdout.write(".")
            sys.stdout.flush()
            # Wait on a live thread instead of sleeping unconditionally, so
            # we stop waiting as soon as the last thread finishes.
            pending[0].join(poll_interval)
            pending = [t for t in pending if t.is_alive()]

        sys.stdout.write(" ")
        sys.stdout.flush()

        # Collect only after every thread has finished. Each outcome is then
        # recorded exactly once and no results list is still being filled.
        results = [t.results for t in threads if t.results]
        exceptions = [t.exc for t in threads if t.exc is not None]
        return results, exceptions

    @mark_slow
    def testThreadedUse(self):
        num = 50
        results, exceptions = self.run_threads(
            num=num,
            thread_class=_TestingThread)
        # Comparing the list (not its length) shows the exceptions on failure.
        self.assertEqual(exceptions, [])
        # Every thread must have produced results.
        self.assertEqual(len(results), num)
        expected = list(range(0, 1000))
        for result in results:
            self.assertEqual(result, expected)

    @mark_slow
    def testErrorThreadedUse(self):
        _, exceptions = self.run_threads(
            num=2,
            thread_class=_TestingErrorThread)
        self.assertEqual(len(exceptions), 2)
        for exc in exceptions:
            self.assertIs(type(exc), MSSQLDatabaseException)

    @mark_slow
    def testErrorSprocThreadedUse(self):
        with error_sproc.create():
            _, exceptions = self.run_threads(
                num=5,
                thread_class=_SprocTestingErrorThread)
        self.assertEqual(len(exceptions), 5)
        for exc in exceptions:
            self.assertIs(type(exc), MSSQLDatabaseException)


# Explicit suite containing every ThreadedTests case. TestLoader is used
# because unittest.makeSuite() is deprecated (3.11) and removed in 3.13;
# loadTestsFromTestCase collects the same "test*" methods.
suite = unittest.TestSuite()
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(ThreadedTests))

if __name__ == '__main__':
    unittest.main()