1import gc 2import time 3import unittest 4import weakref 5 6from ast import Or 7from functools import partial 8from threading import Thread 9from unittest import TestCase 10 11try: 12 import _testcapi 13except ImportError: 14 _testcapi = None 15 16from test.support import threading_helper 17 18 19@threading_helper.requires_working_threading() 20class TestDict(TestCase): 21 def test_racing_creation_shared_keys(self): 22 """Verify that creating dictionaries is thread safe when we 23 have a type with shared keys""" 24 class C(int): 25 pass 26 27 self.racing_creation(C) 28 29 def test_racing_creation_no_shared_keys(self): 30 """Verify that creating dictionaries is thread safe when we 31 have a type with an ordinary dict""" 32 self.racing_creation(Or) 33 34 def test_racing_creation_inline_values_invalid(self): 35 """Verify that re-creating a dict after we have invalid inline values 36 is thread safe""" 37 class C: 38 pass 39 40 def make_obj(): 41 a = C() 42 # Make object, make inline values invalid, and then delete dict 43 a.__dict__ = {} 44 del a.__dict__ 45 return a 46 47 self.racing_creation(make_obj) 48 49 def test_racing_creation_nonmanaged_dict(self): 50 """Verify that explicit creation of an unmanaged dict is thread safe 51 outside of the normal attribute setting code path""" 52 def make_obj(): 53 def f(): pass 54 return f 55 56 def set(func, name, val): 57 # Force creation of the dict via PyObject_GenericGetDict 58 func.__dict__[name] = val 59 60 self.racing_creation(make_obj, set) 61 62 def racing_creation(self, cls, set=setattr): 63 objects = [] 64 processed = [] 65 66 OBJECT_COUNT = 100 67 THREAD_COUNT = 10 68 CUR = 0 69 70 for i in range(OBJECT_COUNT): 71 objects.append(cls()) 72 73 def writer_func(name): 74 last = -1 75 while True: 76 if CUR == last: 77 continue 78 elif CUR == OBJECT_COUNT: 79 break 80 81 obj = objects[CUR] 82 set(obj, name, name) 83 last = CUR 84 processed.append(name) 85 86 writers = [] 87 for x in range(THREAD_COUNT): 88 writer = Thread(target=partial(writer_func, f"a{x:02}")) 89 writers.append(writer) 90 writer.start() 91 92 for i in range(OBJECT_COUNT): 93 CUR = i 94 while len(processed) != THREAD_COUNT: 95 time.sleep(0.001) 96 processed.clear() 97 98 CUR = OBJECT_COUNT 99 100 for writer in writers: 101 writer.join() 102 103 for obj_idx, obj in enumerate(objects): 104 assert ( 105 len(obj.__dict__) == THREAD_COUNT 106 ), f"{len(obj.__dict__)} {obj.__dict__!r} {obj_idx}" 107 for i in range(THREAD_COUNT): 108 assert f"a{i:02}" in obj.__dict__, f"a{i:02} missing at {obj_idx}" 109 110 def test_racing_set_dict(self): 111 """Races assigning to __dict__ should be thread safe""" 112 113 def f(): pass 114 l = [] 115 THREAD_COUNT = 10 116 class MyDict(dict): pass 117 118 def writer_func(l): 119 for i in range(1000): 120 d = MyDict() 121 l.append(weakref.ref(d)) 122 f.__dict__ = d 123 124 lists = [] 125 writers = [] 126 for x in range(THREAD_COUNT): 127 thread_list = [] 128 lists.append(thread_list) 129 writer = Thread(target=partial(writer_func, thread_list)) 130 writers.append(writer) 131 132 for writer in writers: 133 writer.start() 134 135 for writer in writers: 136 writer.join() 137 138 f.__dict__ = {} 139 gc.collect() 140 141 for thread_list in lists: 142 for ref in thread_list: 143 self.assertIsNone(ref()) 144 145 @unittest.skipIf(_testcapi is None, 'need _testcapi module') 146 def test_dict_version(self): 147 dict_version = _testcapi.dict_version 148 THREAD_COUNT = 10 149 DICT_COUNT = 10000 150 lists = [] 151 writers = [] 152 153 def writer_func(thread_list): 154 for i in range(DICT_COUNT): 155 thread_list.append(dict_version({})) 156 157 for x in range(THREAD_COUNT): 158 thread_list = [] 159 lists.append(thread_list) 160 writer = Thread(target=partial(writer_func, thread_list)) 161 writers.append(writer) 162 163 for writer in writers: 164 writer.start() 165 166 for writer in writers: 167 writer.join() 168 169 total_len = 0 170 values = set() 171 for thread_list in lists: 172 for v in thread_list: 173 if v in values: 174 print('dup', v, (v/4096)%256) 175 values.add(v) 176 total_len += len(thread_list) 177 versions = set(dict_version for thread_list in lists for dict_version in thread_list) 178 self.assertEqual(len(versions), THREAD_COUNT*DICT_COUNT) 179 180 181if __name__ == "__main__": 182 unittest.main() 183