"""Unit tests for weakref.WeakSet."""

import unittest
from test import test_support
from weakref import proxy, ref, WeakSet
import operator
import copy
import string
import os
from random import randrange, shuffle
import sys
import warnings
import collections
import gc
import contextlib


class Foo:
    """Bare class with no custom equality or hashing (compares by identity)."""


class SomeClass(object):
    """Weak-referenceable value object used as WeakSet elements.

    Two instances are equal, and hash alike, when they have the same type and
    the same ``value``.  This lets a freshly built object stand in for another
    one in membership tests without sharing its identity.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        if type(other) != type(self):
            return False
        return other.value == self.value

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so define it explicitly.
        return not self.__eq__(other)

    def __hash__(self):
        return hash((SomeClass, self.value))


class RefCycle(object):
    """Object that refers to itself, forming a reference cycle."""

    def __init__(self):
        self.cycle = self


class TestWeakSet(unittest.TestCase):
    """Behavioural tests for WeakSet's set API and its weak-reference semantics."""

    def setUp(self):
        # need to keep references to them: WeakSet only holds weak references,
        # so these lists are what keep the elements alive during each test.
        self.items = [SomeClass(c) for c in ('a', 'b', 'c')]
        self.items2 = [SomeClass(c) for c in ('x', 'y', 'z')]
        self.letters = [SomeClass(c) for c in string.ascii_letters]
        self.ab_items = [SomeClass(c) for c in 'ab']
        self.abcde_items = [SomeClass(c) for c in 'abcde']
        self.def_items = [SomeClass(c) for c in 'def']
        self.ab_weakset = WeakSet(self.ab_items)
        self.abcde_weakset = WeakSet(self.abcde_items)
        self.def_weakset = WeakSet(self.def_items)
        self.s = WeakSet(self.items)
        self.d = dict.fromkeys(self.items)
        self.obj = SomeClass('F')
        self.fs = WeakSet([self.obj])

    def test_methods(self):
        weaksetmethods = dir(WeakSet)
        for method in dir(set):
            if method == 'test_c_api' or method.startswith('_'):
                continue
            self.assertIn(method, weaksetmethods,
                          "WeakSet missing method " + method)

    def test_new_or_init(self):
        self.assertRaises(TypeError, WeakSet, [], 2)

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))
        self.assertEqual(len(self.fs), 1)
        # Dropping the only strong reference should empty the set.
        del self.obj
        self.assertEqual(len(self.fs), 0)

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        # 1 is not weakref'able, but that TypeError is caught by __contains__
        self.assertNotIn(1, self.s)
        self.assertIn(self.obj, self.fs)
        del self.obj
        self.assertNotIn(SomeClass('F'), self.fs)

    def test_union(self):
        u = self.s.union(self.items2)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(u), WeakSet)
        self.assertRaises(TypeError, self.s.union, [[]])
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet(self.items + self.items2)
            c = C(self.items2)
            self.assertEqual(self.s.union(c), x)
            # Release the container's strong references to items2.
            del c
        self.assertEqual(len(u), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(u), len(self.items) + len(self.items2))

    def test_or(self):
        i = self.s.union(self.items2)
        self.assertEqual(self.s | set(self.items2), i)
        self.assertEqual(self.s | frozenset(self.items2), i)

    def test_intersection(self):
        s = WeakSet(self.letters)
        i = s.intersection(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.items2 and c in self.letters)
        self.assertEqual(s, WeakSet(self.letters))
        self.assertEqual(type(i), WeakSet)
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet([])
            self.assertEqual(i.intersection(C(self.items)), x)
        self.assertEqual(len(i), len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items2))

    def test_isdisjoint(self):
        self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
        self.assertFalse(self.s.isdisjoint(WeakSet(self.letters)))

    def test_and(self):
        i = self.s.intersection(self.items2)
        self.assertEqual(self.s & set(self.items2), i)
        self.assertEqual(self.s & frozenset(self.items2), i)

    def test_difference(self):
        i = self.s.difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.difference, [[]])

    def test_sub(self):
        i = self.s.difference(self.items2)
        self.assertEqual(self.s - set(self.items2), i)
        self.assertEqual(self.s - frozenset(self.items2), i)

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        self.assertEqual(len(i), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items) + len(self.items2))

    def test_xor(self):
        i = self.s.symmetric_difference(self.items2)
        self.assertEqual(self.s ^ set(self.items2), i)
        self.assertEqual(self.s ^ frozenset(self.items2), i)

    def test_sub_and_super(self):
        self.assertTrue(self.ab_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset >= self.ab_weakset)
        self.assertFalse(self.abcde_weakset <= self.def_weakset)
        self.assertFalse(self.abcde_weakset >= self.def_weakset)
        # NOTE(review): the checks below exercise the built-in set type,
        # not WeakSet.
        self.assertTrue(set('a').issubset('abc'))
        self.assertTrue(set('abc').issuperset('a'))
        self.assertFalse(set('a').issubset('cbs'))
        self.assertFalse(set('cbs').issuperset('a'))

    def test_lt(self):
        self.assertTrue(self.ab_weakset < self.abcde_weakset)
        self.assertFalse(self.abcde_weakset < self.def_weakset)
        self.assertFalse(self.ab_weakset < self.ab_weakset)
        self.assertFalse(WeakSet() < WeakSet())

    def test_gt(self):
        self.assertTrue(self.abcde_weakset > self.ab_weakset)
        self.assertFalse(self.abcde_weakset > self.def_weakset)
        self.assertFalse(self.ab_weakset > self.ab_weakset)
        self.assertFalse(WeakSet() > WeakSet())

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check
        s = WeakSet(Foo() for i in range(1000))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = WeakSet([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(WeakSet):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_init(self):
        s = WeakSet()
        s.__init__(self.items)
        self.assertEqual(s, self.s)
        # Re-initialising replaces the previous contents.
        s.__init__(self.items2)
        self.assertEqual(s, WeakSet(self.items2))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = WeakSet(self.items)
        t = WeakSet(s)
        self.assertIsNot(s, t)

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, WeakSet([]))
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertIsNot(self.s, dup)

    def test_add(self):
        x = SomeClass('Q')
        self.s.add(x)
        self.assertIn(x, self.s)
        dup = self.s.copy()
        self.s.add(x)
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])
        # The temporary Foo() has no other strong reference, so it does
        # not stay in the set.
        self.fs.add(Foo())
        self.assertEqual(len(self.fs), 1)
        self.fs.add(self.obj)
        self.assertEqual(len(self.fs), 1)

    def test_remove(self):
        x = SomeClass('a')
        self.s.remove(x)
        self.assertNotIn(x, self.s)
        self.assertRaises(KeyError, self.s.remove, x)
        self.assertRaises(TypeError, self.s.remove, [])

    def test_discard(self):
        a, q = SomeClass('a'), SomeClass('Q')
        self.s.discard(a)
        self.assertNotIn(a, self.s)
        self.s.discard(q)
        self.assertRaises(TypeError, self.s.discard, [])

    def test_pop(self):
        for _ in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)
        self.assertRaises(TypeError, self.s.update, [[]])

    def test_update_set(self):
        self.s.update(set(self.items2))
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_ior(self):
        self.s |= set(self.items2)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.intersection_update, [[]])

    def test_iand(self):
        self.s &= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_isub(self):
        self.s -= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_ixor(self):
        self.s ^= set(self.items2)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, WeakSet())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, WeakSet())

    def test_eq(self):
        # issue 5964
        self.assertTrue(self.s == self.s)
        self.assertTrue(self.s == WeakSet(self.items))
        self.assertFalse(self.s == set(self.items))
        self.assertFalse(self.s == list(self.items))
        self.assertFalse(self.s == tuple(self.items))
        self.assertFalse(self.s == 1)

    def test_ne(self):
        self.assertTrue(self.s != set(self.items))
        s1 = WeakSet()
        s2 = WeakSet()
        self.assertFalse(s1 != s2)

    def test_weak_destroy_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        # Create new items to be sure no-one else holds a reference
        items = [SomeClass(c) for c in ('a', 'b', 'c')]
        s = WeakSet(items)
        it = iter(s)
        next(it)             # Trigger internal iteration
        # Destroy an item
        del items[-1]
        gc.collect()    # just in case
        # We have removed either the first consumed items, or another one
        self.assertIn(len(list(it)), [len(items), len(items) - 1])
        del it
        # The removal has been committed
        self.assertEqual(len(s), len(items))

    def test_weak_destroy_and_mutate_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        items = [SomeClass(c) for c in string.ascii_letters]
        s = WeakSet(items)

        @contextlib.contextmanager
        def testcontext():
            """Yield an object equal to an item that was just destroyed
            while an iteration over ``s`` is still in progress."""
            try:
                it = iter(s)
                # Start the iterator, keeping only the value it yielded so
                # this frame holds no strong reference to that element.
                yielded_value = next(it).value
                # Schedule an item for removal and recreate an equal (but
                # distinct) object.  Build it from .value: str() of a
                # SomeClass is its repr, which would never be equal.
                u = SomeClass(items.pop().value)
                if yielded_value == u.value:
                    # The suspended iterator may still hold a strong
                    # reference to the removed item; advance it so the
                    # item can actually die (issue #20006).
                    next(it)
                gc.collect()      # just in case
                yield u
            finally:
                it = None           # should commit all removals

        with testcontext() as u:
            self.assertNotIn(u, s)
        with testcontext() as u:
            self.assertRaises(KeyError, s.remove, u)
        self.assertNotIn(u, s)
        with testcontext() as u:
            s.add(u)
        self.assertIn(u, s)
        t = s.copy()
        with testcontext() as u:
            s.update(t)
        self.assertEqual(len(s), len(t))
        with testcontext() as u:
            s.clear()
        self.assertEqual(len(s), 0)

    def test_len_cycles(self):
        N = 20
        items = [RefCycle() for i in range(N)]
        s = WeakSet(items)
        del items
        it = iter(s)
        try:
            next(it)
        except StopIteration:
            pass
        gc.collect()
        n1 = len(s)
        del it
        gc.collect()
        n2 = len(s)
        # one item may be kept alive inside the iterator
        self.assertIn(n1, (0, 1))
        self.assertEqual(n2, 0)

    def test_len_race(self):
        # Extended sanity checks for len() in the face of cyclic collection
        self.addCleanup(gc.set_threshold, *gc.get_threshold())
        for th in range(1, 100):
            N = 20
            gc.collect(0)
            gc.set_threshold(th, th, th)
            items = [RefCycle() for i in range(N)]
            s = WeakSet(items)
            del items
            # All items will be collected at next garbage collection pass
            it = iter(s)
            try:
                next(it)
            except StopIteration:
                pass
            n1 = len(s)
            del it
            n2 = len(s)
            self.assertGreaterEqual(n1, 0)
            self.assertLessEqual(n1, N)
            self.assertGreaterEqual(n2, 0)
            self.assertLessEqual(n2, n1)


def test_main(verbose=None):
    """Run the TestWeakSet suite via test_support (``verbose`` is unused)."""
    test_support.run_unittest(TestWeakSet)


if __name__ == "__main__":
    test_main(verbose=True)