• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import gc
2import time
3import unittest
4import weakref
5
6from ast import Or
7from functools import partial
8from threading import Thread
9from unittest import TestCase
10
11try:
12    import _testcapi
13except ImportError:
14    _testcapi = None
15
16from test.support import threading_helper
17
18
19@threading_helper.requires_working_threading()
20class TestDict(TestCase):
21    def test_racing_creation_shared_keys(self):
22        """Verify that creating dictionaries is thread safe when we
23        have a type with shared keys"""
24        class C(int):
25            pass
26
27        self.racing_creation(C)
28
29    def test_racing_creation_no_shared_keys(self):
30        """Verify that creating dictionaries is thread safe when we
31        have a type with an ordinary dict"""
32        self.racing_creation(Or)
33
34    def test_racing_creation_inline_values_invalid(self):
35        """Verify that re-creating a dict after we have invalid inline values
36        is thread safe"""
37        class C:
38            pass
39
40        def make_obj():
41            a = C()
42            # Make object, make inline values invalid, and then delete dict
43            a.__dict__ = {}
44            del a.__dict__
45            return a
46
47        self.racing_creation(make_obj)
48
49    def test_racing_creation_nonmanaged_dict(self):
50        """Verify that explicit creation of an unmanaged dict is thread safe
51        outside of the normal attribute setting code path"""
52        def make_obj():
53            def f(): pass
54            return f
55
56        def set(func, name, val):
57            # Force creation of the dict via PyObject_GenericGetDict
58            func.__dict__[name] = val
59
60        self.racing_creation(make_obj, set)
61
62    def racing_creation(self, cls, set=setattr):
63        objects = []
64        processed = []
65
66        OBJECT_COUNT = 100
67        THREAD_COUNT = 10
68        CUR = 0
69
70        for i in range(OBJECT_COUNT):
71            objects.append(cls())
72
73        def writer_func(name):
74            last = -1
75            while True:
76                if CUR == last:
77                    continue
78                elif CUR == OBJECT_COUNT:
79                    break
80
81                obj = objects[CUR]
82                set(obj, name, name)
83                last = CUR
84                processed.append(name)
85
86        writers = []
87        for x in range(THREAD_COUNT):
88            writer = Thread(target=partial(writer_func, f"a{x:02}"))
89            writers.append(writer)
90            writer.start()
91
92        for i in range(OBJECT_COUNT):
93            CUR = i
94            while len(processed) != THREAD_COUNT:
95                time.sleep(0.001)
96            processed.clear()
97
98        CUR = OBJECT_COUNT
99
100        for writer in writers:
101            writer.join()
102
103        for obj_idx, obj in enumerate(objects):
104            assert (
105                len(obj.__dict__) == THREAD_COUNT
106            ), f"{len(obj.__dict__)} {obj.__dict__!r} {obj_idx}"
107            for i in range(THREAD_COUNT):
108                assert f"a{i:02}" in obj.__dict__, f"a{i:02} missing at {obj_idx}"
109
110    def test_racing_set_dict(self):
111        """Races assigning to __dict__ should be thread safe"""
112
113        def f(): pass
114        l = []
115        THREAD_COUNT = 10
116        class MyDict(dict): pass
117
118        def writer_func(l):
119            for i in range(1000):
120                d = MyDict()
121                l.append(weakref.ref(d))
122                f.__dict__ = d
123
124        lists = []
125        writers = []
126        for x in range(THREAD_COUNT):
127            thread_list = []
128            lists.append(thread_list)
129            writer = Thread(target=partial(writer_func, thread_list))
130            writers.append(writer)
131
132        for writer in writers:
133            writer.start()
134
135        for writer in writers:
136            writer.join()
137
138        f.__dict__ = {}
139        gc.collect()
140
141        for thread_list in lists:
142            for ref in thread_list:
143                self.assertIsNone(ref())
144
145    @unittest.skipIf(_testcapi is None, 'need _testcapi module')
146    def test_dict_version(self):
147        dict_version = _testcapi.dict_version
148        THREAD_COUNT = 10
149        DICT_COUNT = 10000
150        lists = []
151        writers = []
152
153        def writer_func(thread_list):
154            for i in range(DICT_COUNT):
155                thread_list.append(dict_version({}))
156
157        for x in range(THREAD_COUNT):
158            thread_list = []
159            lists.append(thread_list)
160            writer = Thread(target=partial(writer_func, thread_list))
161            writers.append(writer)
162
163        for writer in writers:
164            writer.start()
165
166        for writer in writers:
167            writer.join()
168
169        total_len = 0
170        values = set()
171        for thread_list in lists:
172            for v in thread_list:
173                if v in values:
174                    print('dup', v, (v/4096)%256)
175                values.add(v)
176            total_len += len(thread_list)
177        versions = set(dict_version for thread_list in lists for dict_version in thread_list)
178        self.assertEqual(len(versions), THREAD_COUNT*DICT_COUNT)
179
180
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
183