"""Unit tests for weakref.WeakSet.

Covers the set API (construction, membership, union/intersection/difference,
comparisons, in-place operators, copy/clear/pop) and the behaviour specific to
weak references: entries disappear once their referents are collected, and
iteration stays safe while that happens.
"""
import unittest
from weakref import WeakSet
import string
from collections import UserString as ustr
from collections.abc import Set, MutableSet
import gc
import contextlib
from test import support


class Foo:
    # Plain class whose instances can be weakly referenced.
    pass


class RefCycle:
    # Refers to itself, so an otherwise unreferenced instance is only
    # reclaimed by the cyclic garbage collector.
    def __init__(self):
        self.cycle = self


class TestWeakSet(unittest.TestCase):
    """Tests for weakref.WeakSet."""

    def setUp(self):
        # need to keep references to them: a WeakSet only holds weak
        # references, so these lists keep the elements alive during a test.
        # ustr (UserString) wrappers are used because builtin str instances
        # cannot be weakly referenced.
        self.items = [ustr(c) for c in ('a', 'b', 'c')]
        self.items2 = [ustr(c) for c in ('x', 'y', 'z')]
        self.ab_items = [ustr(c) for c in 'ab']
        self.abcde_items = [ustr(c) for c in 'abcde']
        self.def_items = [ustr(c) for c in 'def']
        self.ab_weakset = WeakSet(self.ab_items)
        self.abcde_weakset = WeakSet(self.abcde_items)
        self.def_weakset = WeakSet(self.def_items)
        self.letters = [ustr(c) for c in string.ascii_letters]
        self.s = WeakSet(self.items)
        self.d = dict.fromkeys(self.items)
        # self.obj is the only strong reference to self.fs's single element;
        # tests delete it to check that the entry vanishes.
        self.obj = ustr('F')
        self.fs = WeakSet([self.obj])

    def test_methods(self):
        # Every public method of the builtin set must exist on WeakSet too.
        weaksetmethods = dir(WeakSet)
        for method in dir(set):
            if method == 'test_c_api' or method.startswith('_'):
                continue
            self.assertIn(method, weaksetmethods,
                          "WeakSet missing method " + method)

    def test_new_or_init(self):
        self.assertRaises(TypeError, WeakSet, [], 2)

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))
        self.assertEqual(len(self.fs), 1)
        del self.obj
        support.gc_collect()  # For PyPy or other GCs.
        self.assertEqual(len(self.fs), 0)

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        # 1 is not weakref'able, but that TypeError is caught by __contains__
        self.assertNotIn(1, self.s)
        self.assertIn(self.obj, self.fs)
        del self.obj
        support.gc_collect()  # For PyPy or other GCs.
        self.assertNotIn(ustr('F'), self.fs)

    def test_union(self):
        u = self.s.union(self.items2)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(u), WeakSet)
        self.assertRaises(TypeError, self.s.union, [[]])
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet(self.items + self.items2)
            c = C(self.items2)
            self.assertEqual(self.s.union(c), x)
            # Drop the container: it strongly references items2's elements
            # and would keep the popped element below alive.
            del c
        self.assertEqual(len(u), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        # u only weakly references its elements, so it shrinks as well.
        self.assertEqual(len(u), len(self.items) + len(self.items2))

    def test_or(self):
        i = self.s.union(self.items2)
        self.assertEqual(self.s | set(self.items2), i)
        self.assertEqual(self.s | frozenset(self.items2), i)

    def test_intersection(self):
        s = WeakSet(self.letters)
        i = s.intersection(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.items2 and c in self.letters)
        self.assertEqual(s, WeakSet(self.letters))
        self.assertEqual(type(i), WeakSet)
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet([])
            self.assertEqual(i.intersection(C(self.items)), x)
        self.assertEqual(len(i), len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items2))

    def test_isdisjoint(self):
        self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
        self.assertTrue(not self.s.isdisjoint(WeakSet(self.letters)))

    def test_and(self):
        i = self.s.intersection(self.items2)
        self.assertEqual(self.s & set(self.items2), i)
        self.assertEqual(self.s & frozenset(self.items2), i)

    def test_difference(self):
        i = self.s.difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.difference, [[]])

    def test_sub(self):
        i = self.s.difference(self.items2)
        self.assertEqual(self.s - set(self.items2), i)
        self.assertEqual(self.s - frozenset(self.items2), i)

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        self.assertEqual(len(i), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items) + len(self.items2))

    def test_xor(self):
        i = self.s.symmetric_difference(self.items2)
        self.assertEqual(self.s ^ set(self.items2), i)
        self.assertEqual(self.s ^ frozenset(self.items2), i)

    def test_sub_and_super(self):
        self.assertTrue(self.ab_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset >= self.ab_weakset)
        self.assertFalse(self.abcde_weakset <= self.def_weakset)
        self.assertFalse(self.abcde_weakset >= self.def_weakset)
        self.assertTrue(set('a').issubset('abc'))
        self.assertTrue(set('abc').issuperset('a'))
        self.assertFalse(set('a').issubset('cbs'))
        self.assertFalse(set('cbs').issuperset('a'))
        # The same relations through WeakSet's named methods, checked
        # against plain lists of (equal but distinct) elements.
        self.assertTrue(self.ab_weakset.issubset(self.abcde_items))
        self.assertTrue(self.abcde_weakset.issuperset(self.ab_items))
        self.assertFalse(self.abcde_weakset.issubset(self.def_items))
        self.assertFalse(self.def_weakset.issuperset(self.ab_items))

    def test_lt(self):
        self.assertTrue(self.ab_weakset < self.abcde_weakset)
        self.assertFalse(self.abcde_weakset < self.def_weakset)
        self.assertFalse(self.ab_weakset < self.ab_weakset)
        self.assertFalse(WeakSet() < WeakSet())

    def test_gt(self):
        self.assertTrue(self.abcde_weakset > self.ab_weakset)
        self.assertFalse(self.abcde_weakset > self.def_weakset)
        self.assertFalse(self.ab_weakset > self.ab_weakset)
        self.assertFalse(WeakSet() > WeakSet())

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check
        s = WeakSet(Foo() for i in range(1000))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = WeakSet([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(WeakSet):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_init(self):
        s = WeakSet()
        s.__init__(self.items)
        self.assertEqual(s, self.s)
        # Re-initialising replaces the previous contents.
        s.__init__(self.items2)
        self.assertEqual(s, WeakSet(self.items2))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = WeakSet(self.items)
        t = WeakSet(s)
        self.assertNotEqual(id(s), id(t))

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, WeakSet([]))
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertNotEqual(id(self.s), id(dup))

    def test_add(self):
        x = ustr('Q')
        self.s.add(x)
        self.assertIn(x, self.s)
        dup = self.s.copy()
        self.s.add(x)
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])
        # Nothing else references the new Foo, so it is dropped again.
        self.fs.add(Foo())
        support.gc_collect()  # For PyPy or other GCs.
        self.assertEqual(len(self.fs), 1)
        self.fs.add(self.obj)
        self.assertEqual(len(self.fs), 1)

    def test_remove(self):
        x = ustr('a')
        self.s.remove(x)
        self.assertNotIn(x, self.s)
        self.assertRaises(KeyError, self.s.remove, x)
        self.assertRaises(TypeError, self.s.remove, [])

    def test_discard(self):
        a, q = ustr('a'), ustr('Q')
        self.s.discard(a)
        self.assertNotIn(a, self.s)
        self.s.discard(q)
        self.assertRaises(TypeError, self.s.discard, [])

    def test_pop(self):
        for i in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)
        self.assertRaises(TypeError, self.s.update, [[]])

    def test_update_set(self):
        self.s.update(set(self.items2))
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_ior(self):
        self.s |= set(self.items2)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.intersection_update, [[]])

    def test_iand(self):
        self.s &= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_isub(self):
        self.s -= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_ixor(self):
        self.s ^= set(self.items2)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, WeakSet())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, WeakSet())

    def test_eq(self):
        # issue 5964
        self.assertTrue(self.s == self.s)
        self.assertTrue(self.s == WeakSet(self.items))
        self.assertFalse(self.s == set(self.items))
        self.assertFalse(self.s == list(self.items))
        self.assertFalse(self.s == tuple(self.items))
        self.assertFalse(self.s == WeakSet([Foo]))
        self.assertFalse(self.s == 1)

    def test_ne(self):
        self.assertTrue(self.s != set(self.items))
        s1 = WeakSet()
        s2 = WeakSet()
        self.assertFalse(s1 != s2)

    def test_weak_destroy_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        # Create new items to be sure no-one else holds a reference
        items = [ustr(c) for c in ('a', 'b', 'c')]
        s = WeakSet(items)
        it = iter(s)
        next(it)  # Trigger internal iteration
        # Destroy an item
        del items[-1]
        gc.collect()  # just in case
        # We have removed either the first consumed items, or another one
        self.assertIn(len(list(it)), [len(items), len(items) - 1])
        del it
        # The removal has been committed
        self.assertEqual(len(s), len(items))

    def test_weak_destroy_and_mutate_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        items = [ustr(c) for c in string.ascii_letters]
        s = WeakSet(items)

        @contextlib.contextmanager
        def testcontext():
            try:
                it = iter(s)
                # Start iterator
                yielded = ustr(str(next(it)))
                # Schedule an item for removal and recreate it
                u = ustr(str(items.pop()))
                if yielded == u:
                    # The iterator still has a reference to the removed item,
                    # advance it (issue #20006).
                    next(it)
                gc.collect()  # just in case
                yield u
            finally:
                it = None  # should commit all removals

        with testcontext() as u:
            self.assertNotIn(u, s)
        with testcontext() as u:
            self.assertRaises(KeyError, s.remove, u)
        self.assertNotIn(u, s)
        with testcontext() as u:
            s.add(u)
        self.assertIn(u, s)
        t = s.copy()
        with testcontext() as u:
            s.update(t)
        self.assertEqual(len(s), len(t))
        with testcontext() as u:
            s.clear()
        self.assertEqual(len(s), 0)

    def test_len_cycles(self):
        N = 20
        items = [RefCycle() for i in range(N)]
        s = WeakSet(items)
        del items
        it = iter(s)
        try:
            next(it)
        except StopIteration:
            pass
        gc.collect()
        n1 = len(s)
        del it
        gc.collect()
        gc.collect()  # For PyPy or other GCs.
        n2 = len(s)
        # one item may be kept alive inside the iterator
        self.assertIn(n1, (0, 1))
        self.assertEqual(n2, 0)

    def test_len_race(self):
        # Extended sanity checks for len() in the face of cyclic collection
        self.addCleanup(gc.set_threshold, *gc.get_threshold())
        for th in range(1, 100):
            N = 20
            gc.collect(0)
            gc.set_threshold(th, th, th)
            items = [RefCycle() for i in range(N)]
            s = WeakSet(items)
            del items
            # All items will be collected at next garbage collection pass
            it = iter(s)
            try:
                next(it)
            except StopIteration:
                pass
            n1 = len(s)
            del it
            n2 = len(s)
            self.assertGreaterEqual(n1, 0)
            self.assertLessEqual(n1, N)
            self.assertGreaterEqual(n2, 0)
            self.assertLessEqual(n2, n1)

    def test_repr(self):
        # Use the TestCase assertion: a bare ``assert`` is stripped under -O.
        self.assertEqual(repr(self.s), repr(self.s.data))

    def test_abc(self):
        self.assertIsInstance(self.s, Set)
        self.assertIsInstance(self.s, MutableSet)


if __name__ == "__main__":
    unittest.main()