File: test_dict.py

package info (click to toggle)
python3.13 3.13.6-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 121,256 kB
  • sloc: python: 703,743; ansic: 653,888; xml: 31,250; sh: 5,844; cpp: 4,326; makefile: 1,981; objc: 787; lisp: 502; javascript: 213; asm: 75; csh: 12
file content (201 lines) | stat: -rw-r--r-- 5,572 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import gc
import time
import unittest
import weakref

from ast import Or
from functools import partial
from threading import Thread, Barrier
from unittest import TestCase

try:
    import _testcapi
except ImportError:
    _testcapi = None

from test.support import threading_helper


@threading_helper.requires_working_threading()
class TestDict(TestCase):
    def test_racing_creation_shared_keys(self):
        """Verify that creating dictionaries is thread safe when we
        have a type with shared keys"""
        class C(int):
            pass

        self.racing_creation(C)

    def test_racing_creation_no_shared_keys(self):
        """Verify that creating dictionaries is thread safe when we
        have a type with an ordinary dict"""
        self.racing_creation(Or)

    def test_racing_creation_inline_values_invalid(self):
        """Verify that re-creating a dict after we have invalid inline values
        is thread safe"""
        class C:
            pass

        def make_obj():
            a = C()
            # Make object, make inline values invalid, and then delete dict
            a.__dict__ = {}
            del a.__dict__
            return a

        self.racing_creation(make_obj)

    def test_racing_creation_nonmanaged_dict(self):
        """Verify that explicit creation of an unmanaged dict is thread safe
        outside of the normal attribute setting code path"""
        def make_obj():
            def f(): pass
            return f

        def set(func, name, val):
            # Force creation of the dict via PyObject_GenericGetDict
            func.__dict__[name] = val

        self.racing_creation(make_obj, set)

    def racing_creation(self, cls, set=setattr):
        objects = []
        processed = []

        OBJECT_COUNT = 100
        THREAD_COUNT = 10
        CUR = 0

        for i in range(OBJECT_COUNT):
            objects.append(cls())

        def writer_func(name):
            last = -1
            while True:
                if CUR == last:
                    continue
                elif CUR == OBJECT_COUNT:
                    break

                obj = objects[CUR]
                set(obj, name, name)
                last = CUR
                processed.append(name)

        writers = []
        for x in range(THREAD_COUNT):
            writer = Thread(target=partial(writer_func, f"a{x:02}"))
            writers.append(writer)
            writer.start()

        for i in range(OBJECT_COUNT):
            CUR = i
            while len(processed) != THREAD_COUNT:
                time.sleep(0.001)
            processed.clear()

        CUR = OBJECT_COUNT

        for writer in writers:
            writer.join()

        for obj_idx, obj in enumerate(objects):
            assert (
                len(obj.__dict__) == THREAD_COUNT
            ), f"{len(obj.__dict__)} {obj.__dict__!r} {obj_idx}"
            for i in range(THREAD_COUNT):
                assert f"a{i:02}" in obj.__dict__, f"a{i:02} missing at {obj_idx}"

    def test_racing_set_dict(self):
        """Races assigning to __dict__ should be thread safe"""

        def f(): pass
        l = []
        THREAD_COUNT = 10
        class MyDict(dict): pass

        def writer_func(l):
            for i in range(1000):
                d = MyDict()
                l.append(weakref.ref(d))
                f.__dict__ = d

        lists = []
        writers = []
        for x in range(THREAD_COUNT):
            thread_list = []
            lists.append(thread_list)
            writer = Thread(target=partial(writer_func, thread_list))
            writers.append(writer)

        for writer in writers:
            writer.start()

        for writer in writers:
            writer.join()

        f.__dict__ = {}
        gc.collect()

        for thread_list in lists:
            for ref in thread_list:
                self.assertIsNone(ref())

    def test_getattr_setattr(self):
        NUM_THREADS = 10
        b = Barrier(NUM_THREADS)

        def closure(b, c):
            b.wait()
            for i in range(10):
                getattr(c, f'attr_{i}', None)
                setattr(c, f'attr_{i}', 99)

        class MyClass:
            pass

        o = MyClass()
        threads = [Thread(target=closure, args=(b, o))
                for _ in range(NUM_THREADS)]
        with threading_helper.start_threads(threads):
            pass

    @unittest.skipIf(_testcapi is None, 'need _testcapi module')
    def test_dict_version(self):
        dict_version = _testcapi.dict_version
        THREAD_COUNT = 10
        DICT_COUNT = 10000
        lists = []
        writers = []

        def writer_func(thread_list):
            for i in range(DICT_COUNT):
                thread_list.append(dict_version({}))

        for x in range(THREAD_COUNT):
            thread_list = []
            lists.append(thread_list)
            writer = Thread(target=partial(writer_func, thread_list))
            writers.append(writer)

        for writer in writers:
            writer.start()

        for writer in writers:
            writer.join()

        total_len = 0
        values = set()
        for thread_list in lists:
            for v in thread_list:
                if v in values:
                    print('dup', v, (v/4096)%256)
                values.add(v)
            total_len += len(thread_list)
        versions = set(dict_version for thread_list in lists for dict_version in thread_list)
        self.assertEqual(len(versions), THREAD_COUNT*DICT_COUNT)


# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()