1import unittest 2from test import support 3import gc 4import weakref 5import operator 6import copy 7import pickle 8from random import randrange, shuffle 9import warnings 10import collections 11import collections.abc 12import itertools 13 14class PassThru(Exception): 15 pass 16 17def check_pass_thru(): 18 raise PassThru 19 yield 1 20 21class BadCmp: 22 def __hash__(self): 23 return 1 24 def __eq__(self, other): 25 raise RuntimeError 26 27class ReprWrapper: 28 'Used to test self-referential repr() calls' 29 def __repr__(self): 30 return repr(self.value) 31 32class HashCountingInt(int): 33 'int-like object that counts the number of times __hash__ is called' 34 def __init__(self, *args): 35 self.hash_count = 0 36 def __hash__(self): 37 self.hash_count += 1 38 return int.__hash__(self) 39 40class TestJointOps: 41 # Tests common to both set and frozenset 42 43 def setUp(self): 44 self.word = word = 'simsalabim' 45 self.otherword = 'madagascar' 46 self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' 47 self.s = self.thetype(word) 48 self.d = dict.fromkeys(word) 49 50 def test_new_or_init(self): 51 self.assertRaises(TypeError, self.thetype, [], 2) 52 self.assertRaises(TypeError, set().__init__, a=1) 53 54 def test_uniquification(self): 55 actual = sorted(self.s) 56 expected = sorted(self.d) 57 self.assertEqual(actual, expected) 58 self.assertRaises(PassThru, self.thetype, check_pass_thru()) 59 self.assertRaises(TypeError, self.thetype, [[]]) 60 61 def test_len(self): 62 self.assertEqual(len(self.s), len(self.d)) 63 64 def test_contains(self): 65 for c in self.letters: 66 self.assertEqual(c in self.s, c in self.d) 67 self.assertRaises(TypeError, self.s.__contains__, [[]]) 68 s = self.thetype([frozenset(self.letters)]) 69 self.assertIn(self.thetype(self.letters), s) 70 71 def test_union(self): 72 u = self.s.union(self.otherword) 73 for c in self.letters: 74 self.assertEqual(c in u, c in self.d or c in self.otherword) 75 self.assertEqual(self.s, self.thetype(self.word)) 76 self.assertEqual(type(u), self.basetype) 77 self.assertRaises(PassThru, self.s.union, check_pass_thru()) 78 self.assertRaises(TypeError, self.s.union, [[]]) 79 for C in set, frozenset, dict.fromkeys, str, list, tuple: 80 self.assertEqual(self.thetype('abcba').union(C('cdc')), set('abcd')) 81 self.assertEqual(self.thetype('abcba').union(C('efgfe')), set('abcefg')) 82 self.assertEqual(self.thetype('abcba').union(C('ccb')), set('abc')) 83 self.assertEqual(self.thetype('abcba').union(C('ef')), set('abcef')) 84 self.assertEqual(self.thetype('abcba').union(C('ef'), C('fg')), set('abcefg')) 85 86 # Issue #6573 87 x = self.thetype() 88 self.assertEqual(x.union(set([1]), x, set([2])), self.thetype([1, 2])) 89 90 def test_or(self): 91 i = self.s.union(self.otherword) 92 self.assertEqual(self.s | set(self.otherword), i) 93 self.assertEqual(self.s | frozenset(self.otherword), i) 94 try: 95 self.s | self.otherword 96 except TypeError: 97 pass 98 else: 99 self.fail("s|t did not screen-out general iterables") 100 101 def test_intersection(self): 102 i = self.s.intersection(self.otherword) 103 for c in self.letters: 104 self.assertEqual(c in i, c in self.d and c in self.otherword) 105 self.assertEqual(self.s, self.thetype(self.word)) 106 self.assertEqual(type(i), self.basetype) 107 self.assertRaises(PassThru, self.s.intersection, check_pass_thru()) 108 for C in set, frozenset, dict.fromkeys, str, list, tuple: 109 self.assertEqual(self.thetype('abcba').intersection(C('cdc')), set('cc')) 110 self.assertEqual(self.thetype('abcba').intersection(C('efgfe')), set('')) 111 self.assertEqual(self.thetype('abcba').intersection(C('ccb')), set('bc')) 112 self.assertEqual(self.thetype('abcba').intersection(C('ef')), set('')) 113 self.assertEqual(self.thetype('abcba').intersection(C('cbcf'), C('bag')), set('b')) 114 s = self.thetype('abcba') 115 z = s.intersection() 116 if self.thetype == frozenset(): 117 self.assertEqual(id(s), id(z)) 118 else: 119 self.assertNotEqual(id(s), id(z)) 120 121 def test_isdisjoint(self): 122 def f(s1, s2): 123 'Pure python equivalent of isdisjoint()' 124 return not set(s1).intersection(s2) 125 for larg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef': 126 s1 = self.thetype(larg) 127 for rarg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef': 128 for C in set, frozenset, dict.fromkeys, str, list, tuple: 129 s2 = C(rarg) 130 actual = s1.isdisjoint(s2) 131 expected = f(s1, s2) 132 self.assertEqual(actual, expected) 133 self.assertTrue(actual is True or actual is False) 134 135 def test_and(self): 136 i = self.s.intersection(self.otherword) 137 self.assertEqual(self.s & set(self.otherword), i) 138 self.assertEqual(self.s & frozenset(self.otherword), i) 139 try: 140 self.s & self.otherword 141 except TypeError: 142 pass 143 else: 144 self.fail("s&t did not screen-out general iterables") 145 146 def test_difference(self): 147 i = self.s.difference(self.otherword) 148 for c in self.letters: 149 self.assertEqual(c in i, c in self.d and c not in self.otherword) 150 self.assertEqual(self.s, self.thetype(self.word)) 151 self.assertEqual(type(i), self.basetype) 152 self.assertRaises(PassThru, self.s.difference, check_pass_thru()) 153 self.assertRaises(TypeError, self.s.difference, [[]]) 154 for C in set, frozenset, dict.fromkeys, str, list, tuple: 155 self.assertEqual(self.thetype('abcba').difference(C('cdc')), set('ab')) 156 self.assertEqual(self.thetype('abcba').difference(C('efgfe')), set('abc')) 157 self.assertEqual(self.thetype('abcba').difference(C('ccb')), set('a')) 158 self.assertEqual(self.thetype('abcba').difference(C('ef')), set('abc')) 159 self.assertEqual(self.thetype('abcba').difference(), set('abc')) 160 self.assertEqual(self.thetype('abcba').difference(C('a'), C('b')), set('c')) 161 162 def test_sub(self): 163 i = self.s.difference(self.otherword) 164 self.assertEqual(self.s - set(self.otherword), i) 165 self.assertEqual(self.s - frozenset(self.otherword), i) 166 try: 167 self.s - self.otherword 168 except TypeError: 169 pass 170 else: 171 self.fail("s-t did not screen-out general iterables") 172 173 def test_symmetric_difference(self): 174 i = self.s.symmetric_difference(self.otherword) 175 for c in self.letters: 176 self.assertEqual(c in i, (c in self.d) ^ (c in self.otherword)) 177 self.assertEqual(self.s, self.thetype(self.word)) 178 self.assertEqual(type(i), self.basetype) 179 self.assertRaises(PassThru, self.s.symmetric_difference, check_pass_thru()) 180 self.assertRaises(TypeError, self.s.symmetric_difference, [[]]) 181 for C in set, frozenset, dict.fromkeys, str, list, tuple: 182 self.assertEqual(self.thetype('abcba').symmetric_difference(C('cdc')), set('abd')) 183 self.assertEqual(self.thetype('abcba').symmetric_difference(C('efgfe')), set('abcefg')) 184 self.assertEqual(self.thetype('abcba').symmetric_difference(C('ccb')), set('a')) 185 self.assertEqual(self.thetype('abcba').symmetric_difference(C('ef')), set('abcef')) 186 187 def test_xor(self): 188 i = self.s.symmetric_difference(self.otherword) 189 self.assertEqual(self.s ^ set(self.otherword), i) 190 self.assertEqual(self.s ^ frozenset(self.otherword), i) 191 try: 192 self.s ^ self.otherword 193 except TypeError: 194 pass 195 else: 196 self.fail("s^t did not screen-out general iterables") 197 198 def test_equality(self): 199 self.assertEqual(self.s, set(self.word)) 200 self.assertEqual(self.s, frozenset(self.word)) 201 self.assertEqual(self.s == self.word, False) 202 self.assertNotEqual(self.s, set(self.otherword)) 203 self.assertNotEqual(self.s, frozenset(self.otherword)) 204 self.assertEqual(self.s != self.word, True) 205 206 def test_setOfFrozensets(self): 207 t = map(frozenset, ['abcdef', 'bcd', 'bdcb', 'fed', 'fedccba']) 208 s = self.thetype(t) 209 self.assertEqual(len(s), 3) 210 211 def test_sub_and_super(self): 212 p, q, r = map(self.thetype, ['ab', 'abcde', 'def']) 213 self.assertTrue(p < q) 214 self.assertTrue(p <= q) 215 self.assertTrue(q <= q) 216 self.assertTrue(q > p) 217 self.assertTrue(q >= p) 218 self.assertFalse(q < r) 219 self.assertFalse(q <= r) 220 self.assertFalse(q > r) 221 self.assertFalse(q >= r) 222 self.assertTrue(set('a').issubset('abc')) 223 self.assertTrue(set('abc').issuperset('a')) 224 self.assertFalse(set('a').issubset('cbs')) 225 self.assertFalse(set('cbs').issuperset('a')) 226 227 def test_pickling(self): 228 for i in range(pickle.HIGHEST_PROTOCOL + 1): 229 p = pickle.dumps(self.s, i) 230 dup = pickle.loads(p) 231 self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup)) 232 if type(self.s) not in (set, frozenset): 233 self.s.x = 10 234 p = pickle.dumps(self.s, i) 235 dup = pickle.loads(p) 236 self.assertEqual(self.s.x, dup.x) 237 238 def test_iterator_pickling(self): 239 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 240 itorg = iter(self.s) 241 data = self.thetype(self.s) 242 d = pickle.dumps(itorg, proto) 243 it = pickle.loads(d) 244 # Set iterators unpickle as list iterators due to the 245 # undefined order of set items. 246 # self.assertEqual(type(itorg), type(it)) 247 self.assertIsInstance(it, collections.abc.Iterator) 248 self.assertEqual(self.thetype(it), data) 249 250 it = pickle.loads(d) 251 try: 252 drop = next(it) 253 except StopIteration: 254 continue 255 d = pickle.dumps(it, proto) 256 it = pickle.loads(d) 257 self.assertEqual(self.thetype(it), data - self.thetype((drop,))) 258 259 def test_deepcopy(self): 260 class Tracer: 261 def __init__(self, value): 262 self.value = value 263 def __hash__(self): 264 return self.value 265 def __deepcopy__(self, memo=None): 266 return Tracer(self.value + 1) 267 t = Tracer(10) 268 s = self.thetype([t]) 269 dup = copy.deepcopy(s) 270 self.assertNotEqual(id(s), id(dup)) 271 for elem in dup: 272 newt = elem 273 self.assertNotEqual(id(t), id(newt)) 274 self.assertEqual(t.value + 1, newt.value) 275 276 def test_gc(self): 277 # Create a nest of cycles to exercise overall ref count check 278 class A: 279 pass 280 s = set(A() for i in range(1000)) 281 for elem in s: 282 elem.cycle = s 283 elem.sub = elem 284 elem.set = set([elem]) 285 286 def test_subclass_with_custom_hash(self): 287 # Bug #1257731 288 class H(self.thetype): 289 def __hash__(self): 290 return int(id(self) & 0x7fffffff) 291 s=H() 292 f=set() 293 f.add(s) 294 self.assertIn(s, f) 295 f.remove(s) 296 f.add(s) 297 f.discard(s) 298 299 def test_badcmp(self): 300 s = self.thetype([BadCmp()]) 301 # Detect comparison errors during insertion and lookup 302 self.assertRaises(RuntimeError, self.thetype, [BadCmp(), BadCmp()]) 303 self.assertRaises(RuntimeError, s.__contains__, BadCmp()) 304 # Detect errors during mutating operations 305 if hasattr(s, 'add'): 306 self.assertRaises(RuntimeError, s.add, BadCmp()) 307 self.assertRaises(RuntimeError, s.discard, BadCmp()) 308 self.assertRaises(RuntimeError, s.remove, BadCmp()) 309 310 def test_cyclical_repr(self): 311 w = ReprWrapper() 312 s = self.thetype([w]) 313 w.value = s 314 if self.thetype == set: 315 self.assertEqual(repr(s), '{set(...)}') 316 else: 317 name = repr(s).partition('(')[0] # strip class name 318 self.assertEqual(repr(s), '%s({%s(...)})' % (name, name)) 319 320 def test_cyclical_print(self): 321 w = ReprWrapper() 322 s = self.thetype([w]) 323 w.value = s 324 fo = open(support.TESTFN, "w") 325 try: 326 fo.write(str(s)) 327 fo.close() 328 fo = open(support.TESTFN, "r") 329 self.assertEqual(fo.read(), repr(s)) 330 finally: 331 fo.close() 332 support.unlink(support.TESTFN) 333 334 def test_do_not_rehash_dict_keys(self): 335 n = 10 336 d = dict.fromkeys(map(HashCountingInt, range(n))) 337 self.assertEqual(sum(elem.hash_count for elem in d), n) 338 s = self.thetype(d) 339 self.assertEqual(sum(elem.hash_count for elem in d), n) 340 s.difference(d) 341 self.assertEqual(sum(elem.hash_count for elem in d), n) 342 if hasattr(s, 'symmetric_difference_update'): 343 s.symmetric_difference_update(d) 344 self.assertEqual(sum(elem.hash_count for elem in d), n) 345 d2 = dict.fromkeys(set(d)) 346 self.assertEqual(sum(elem.hash_count for elem in d), n) 347 d3 = dict.fromkeys(frozenset(d)) 348 self.assertEqual(sum(elem.hash_count for elem in d), n) 349 d3 = dict.fromkeys(frozenset(d), 123) 350 self.assertEqual(sum(elem.hash_count for elem in d), n) 351 self.assertEqual(d3, dict.fromkeys(d, 123)) 352 353 def test_container_iterator(self): 354 # Bug #3680: tp_traverse was not implemented for set iterator object 355 class C(object): 356 pass 357 obj = C() 358 ref = weakref.ref(obj) 359 container = set([obj, 1]) 360 obj.x = iter(container) 361 del obj, container 362 gc.collect() 363 self.assertTrue(ref() is None, "Cycle was not collected") 364 365 def test_free_after_iterating(self): 366 support.check_free_after_iterating(self, iter, self.thetype) 367 368class TestSet(TestJointOps, unittest.TestCase): 369 thetype = set 370 basetype = set 371 372 def test_init(self): 373 s = self.thetype() 374 s.__init__(self.word) 375 self.assertEqual(s, set(self.word)) 376 s.__init__(self.otherword) 377 self.assertEqual(s, set(self.otherword)) 378 self.assertRaises(TypeError, s.__init__, s, 2); 379 self.assertRaises(TypeError, s.__init__, 1); 380 381 def test_constructor_identity(self): 382 s = self.thetype(range(3)) 383 t = self.thetype(s) 384 self.assertNotEqual(id(s), id(t)) 385 386 def test_set_literal(self): 387 s = set([1,2,3]) 388 t = {1,2,3} 389 self.assertEqual(s, t) 390 391 def test_set_literal_insertion_order(self): 392 # SF Issue #26020 -- Expect left to right insertion 393 s = {1, 1.0, True} 394 self.assertEqual(len(s), 1) 395 stored_value = s.pop() 396 self.assertEqual(type(stored_value), int) 397 398 def test_set_literal_evaluation_order(self): 399 # Expect left to right expression evaluation 400 events = [] 401 def record(obj): 402 events.append(obj) 403 s = {record(1), record(2), record(3)} 404 self.assertEqual(events, [1, 2, 3]) 405 406 def test_hash(self): 407 self.assertRaises(TypeError, hash, self.s) 408 409 def test_clear(self): 410 self.s.clear() 411 self.assertEqual(self.s, set()) 412 self.assertEqual(len(self.s), 0) 413 414 def test_copy(self): 415 dup = self.s.copy() 416 self.assertEqual(self.s, dup) 417 self.assertNotEqual(id(self.s), id(dup)) 418 self.assertEqual(type(dup), self.basetype) 419 420 def test_add(self): 421 self.s.add('Q') 422 self.assertIn('Q', self.s) 423 dup = self.s.copy() 424 self.s.add('Q') 425 self.assertEqual(self.s, dup) 426 self.assertRaises(TypeError, self.s.add, []) 427 428 def test_remove(self): 429 self.s.remove('a') 430 self.assertNotIn('a', self.s) 431 self.assertRaises(KeyError, self.s.remove, 'Q') 432 self.assertRaises(TypeError, self.s.remove, []) 433 s = self.thetype([frozenset(self.word)]) 434 self.assertIn(self.thetype(self.word), s) 435 s.remove(self.thetype(self.word)) 436 self.assertNotIn(self.thetype(self.word), s) 437 self.assertRaises(KeyError, self.s.remove, self.thetype(self.word)) 438 439 def test_remove_keyerror_unpacking(self): 440 # bug: www.python.org/sf/1576657 441 for v1 in ['Q', (1,)]: 442 try: 443 self.s.remove(v1) 444 except KeyError as e: 445 v2 = e.args[0] 446 self.assertEqual(v1, v2) 447 else: 448 self.fail() 449 450 def test_remove_keyerror_set(self): 451 key = self.thetype([3, 4]) 452 try: 453 self.s.remove(key) 454 except KeyError as e: 455 self.assertTrue(e.args[0] is key, 456 "KeyError should be {0}, not {1}".format(key, 457 e.args[0])) 458 else: 459 self.fail() 460 461 def test_discard(self): 462 self.s.discard('a') 463 self.assertNotIn('a', self.s) 464 self.s.discard('Q') 465 self.assertRaises(TypeError, self.s.discard, []) 466 s = self.thetype([frozenset(self.word)]) 467 self.assertIn(self.thetype(self.word), s) 468 s.discard(self.thetype(self.word)) 469 self.assertNotIn(self.thetype(self.word), s) 470 s.discard(self.thetype(self.word)) 471 472 def test_pop(self): 473 for i in range(len(self.s)): 474 elem = self.s.pop() 475 self.assertNotIn(elem, self.s) 476 self.assertRaises(KeyError, self.s.pop) 477 478 def test_update(self): 479 retval = self.s.update(self.otherword) 480 self.assertEqual(retval, None) 481 for c in (self.word + self.otherword): 482 self.assertIn(c, self.s) 483 self.assertRaises(PassThru, self.s.update, check_pass_thru()) 484 self.assertRaises(TypeError, self.s.update, [[]]) 485 for p, q in (('cdc', 'abcd'), ('efgfe', 'abcefg'), ('ccb', 'abc'), ('ef', 'abcef')): 486 for C in set, frozenset, dict.fromkeys, str, list, tuple: 487 s = self.thetype('abcba') 488 self.assertEqual(s.update(C(p)), None) 489 self.assertEqual(s, set(q)) 490 for p in ('cdc', 'efgfe', 'ccb', 'ef', 'abcda'): 491 q = 'ahi' 492 for C in set, frozenset, dict.fromkeys, str, list, tuple: 493 s = self.thetype('abcba') 494 self.assertEqual(s.update(C(p), C(q)), None) 495 self.assertEqual(s, set(s) | set(p) | set(q)) 496 497 def test_ior(self): 498 self.s |= set(self.otherword) 499 for c in (self.word + self.otherword): 500 self.assertIn(c, self.s) 501 502 def test_intersection_update(self): 503 retval = self.s.intersection_update(self.otherword) 504 self.assertEqual(retval, None) 505 for c in (self.word + self.otherword): 506 if c in self.otherword and c in self.word: 507 self.assertIn(c, self.s) 508 else: 509 self.assertNotIn(c, self.s) 510 self.assertRaises(PassThru, self.s.intersection_update, check_pass_thru()) 511 self.assertRaises(TypeError, self.s.intersection_update, [[]]) 512 for p, q in (('cdc', 'c'), ('efgfe', ''), ('ccb', 'bc'), ('ef', '')): 513 for C in set, frozenset, dict.fromkeys, str, list, tuple: 514 s = self.thetype('abcba') 515 self.assertEqual(s.intersection_update(C(p)), None) 516 self.assertEqual(s, set(q)) 517 ss = 'abcba' 518 s = self.thetype(ss) 519 t = 'cbc' 520 self.assertEqual(s.intersection_update(C(p), C(t)), None) 521 self.assertEqual(s, set('abcba')&set(p)&set(t)) 522 523 def test_iand(self): 524 self.s &= set(self.otherword) 525 for c in (self.word + self.otherword): 526 if c in self.otherword and c in self.word: 527 self.assertIn(c, self.s) 528 else: 529 self.assertNotIn(c, self.s) 530 531 def test_difference_update(self): 532 retval = self.s.difference_update(self.otherword) 533 self.assertEqual(retval, None) 534 for c in (self.word + self.otherword): 535 if c in self.word and c not in self.otherword: 536 self.assertIn(c, self.s) 537 else: 538 self.assertNotIn(c, self.s) 539 self.assertRaises(PassThru, self.s.difference_update, check_pass_thru()) 540 self.assertRaises(TypeError, self.s.difference_update, [[]]) 541 self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]]) 542 for p, q in (('cdc', 'ab'), ('efgfe', 'abc'), ('ccb', 'a'), ('ef', 'abc')): 543 for C in set, frozenset, dict.fromkeys, str, list, tuple: 544 s = self.thetype('abcba') 545 self.assertEqual(s.difference_update(C(p)), None) 546 self.assertEqual(s, set(q)) 547 548 s = self.thetype('abcdefghih') 549 s.difference_update() 550 self.assertEqual(s, self.thetype('abcdefghih')) 551 552 s = self.thetype('abcdefghih') 553 s.difference_update(C('aba')) 554 self.assertEqual(s, self.thetype('cdefghih')) 555 556 s = self.thetype('abcdefghih') 557 s.difference_update(C('cdc'), C('aba')) 558 self.assertEqual(s, self.thetype('efghih')) 559 560 def test_isub(self): 561 self.s -= set(self.otherword) 562 for c in (self.word + self.otherword): 563 if c in self.word and c not in self.otherword: 564 self.assertIn(c, self.s) 565 else: 566 self.assertNotIn(c, self.s) 567 568 def test_symmetric_difference_update(self): 569 retval = self.s.symmetric_difference_update(self.otherword) 570 self.assertEqual(retval, None) 571 for c in (self.word + self.otherword): 572 if (c in self.word) ^ (c in self.otherword): 573 self.assertIn(c, self.s) 574 else: 575 self.assertNotIn(c, self.s) 576 self.assertRaises(PassThru, self.s.symmetric_difference_update, check_pass_thru()) 577 self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]]) 578 for p, q in (('cdc', 'abd'), ('efgfe', 'abcefg'), ('ccb', 'a'), ('ef', 'abcef')): 579 for C in set, frozenset, dict.fromkeys, str, list, tuple: 580 s = self.thetype('abcba') 581 self.assertEqual(s.symmetric_difference_update(C(p)), None) 582 self.assertEqual(s, set(q)) 583 584 def test_ixor(self): 585 self.s ^= set(self.otherword) 586 for c in (self.word + self.otherword): 587 if (c in self.word) ^ (c in self.otherword): 588 self.assertIn(c, self.s) 589 else: 590 self.assertNotIn(c, self.s) 591 592 def test_inplace_on_self(self): 593 t = self.s.copy() 594 t |= t 595 self.assertEqual(t, self.s) 596 t &= t 597 self.assertEqual(t, self.s) 598 t -= t 599 self.assertEqual(t, self.thetype()) 600 t = self.s.copy() 601 t ^= t 602 self.assertEqual(t, self.thetype()) 603 604 def test_weakref(self): 605 s = self.thetype('gallahad') 606 p = weakref.proxy(s) 607 self.assertEqual(str(p), str(s)) 608 s = None 609 self.assertRaises(ReferenceError, str, p) 610 611 def test_rich_compare(self): 612 class TestRichSetCompare: 613 def __gt__(self, some_set): 614 self.gt_called = True 615 return False 616 def __lt__(self, some_set): 617 self.lt_called = True 618 return False 619 def __ge__(self, some_set): 620 self.ge_called = True 621 return False 622 def __le__(self, some_set): 623 self.le_called = True 624 return False 625 626 # This first tries the builtin rich set comparison, which doesn't know 627 # how to handle the custom object. Upon returning NotImplemented, the 628 # corresponding comparison on the right object is invoked. 629 myset = {1, 2, 3} 630 631 myobj = TestRichSetCompare() 632 myset < myobj 633 self.assertTrue(myobj.gt_called) 634 635 myobj = TestRichSetCompare() 636 myset > myobj 637 self.assertTrue(myobj.lt_called) 638 639 myobj = TestRichSetCompare() 640 myset <= myobj 641 self.assertTrue(myobj.ge_called) 642 643 myobj = TestRichSetCompare() 644 myset >= myobj 645 self.assertTrue(myobj.le_called) 646 647 @unittest.skipUnless(hasattr(set, "test_c_api"), 648 'C API test only available in a debug build') 649 def test_c_api(self): 650 self.assertEqual(set().test_c_api(), True) 651 652class SetSubclass(set): 653 pass 654 655class TestSetSubclass(TestSet): 656 thetype = SetSubclass 657 basetype = set 658 659class SetSubclassWithKeywordArgs(set): 660 def __init__(self, iterable=[], newarg=None): 661 set.__init__(self, iterable) 662 663class TestSetSubclassWithKeywordArgs(TestSet): 664 665 def test_keywords_in_subclass(self): 666 'SF bug #1486663 -- this used to erroneously raise a TypeError' 667 SetSubclassWithKeywordArgs(newarg=1) 668 669class TestFrozenSet(TestJointOps, unittest.TestCase): 670 thetype = frozenset 671 basetype = frozenset 672 673 def test_init(self): 674 s = self.thetype(self.word) 675 s.__init__(self.otherword) 676 self.assertEqual(s, set(self.word)) 677 678 def test_singleton_empty_frozenset(self): 679 f = frozenset() 680 efs = [frozenset(), frozenset([]), frozenset(()), frozenset(''), 681 frozenset(), frozenset([]), frozenset(()), frozenset(''), 682 frozenset(range(0)), frozenset(frozenset()), 683 frozenset(f), f] 684 # All of the empty frozensets should have just one id() 685 self.assertEqual(len(set(map(id, efs))), 1) 686 687 def test_constructor_identity(self): 688 s = self.thetype(range(3)) 689 t = self.thetype(s) 690 self.assertEqual(id(s), id(t)) 691 692 def test_hash(self): 693 self.assertEqual(hash(self.thetype('abcdeb')), 694 hash(self.thetype('ebecda'))) 695 696 # make sure that all permutations give the same hash value 697 n = 100 698 seq = [randrange(n) for i in range(n)] 699 results = set() 700 for i in range(200): 701 shuffle(seq) 702 results.add(hash(self.thetype(seq))) 703 self.assertEqual(len(results), 1) 704 705 def test_copy(self): 706 dup = self.s.copy() 707 self.assertEqual(id(self.s), id(dup)) 708 709 def test_frozen_as_dictkey(self): 710 seq = list(range(10)) + list('abcdefg') + ['apple'] 711 key1 = self.thetype(seq) 712 key2 = self.thetype(reversed(seq)) 713 self.assertEqual(key1, key2) 714 self.assertNotEqual(id(key1), id(key2)) 715 d = {} 716 d[key1] = 42 717 self.assertEqual(d[key2], 42) 718 719 def test_hash_caching(self): 720 f = self.thetype('abcdcda') 721 self.assertEqual(hash(f), hash(f)) 722 723 def test_hash_effectiveness(self): 724 n = 13 725 hashvalues = set() 726 addhashvalue = hashvalues.add 727 elemmasks = [(i+1, 1<<i) for i in range(n)] 728 for i in range(2**n): 729 addhashvalue(hash(frozenset([e for e, m in elemmasks if m&i]))) 730 self.assertEqual(len(hashvalues), 2**n) 731 732 def zf_range(n): 733 # https://en.wikipedia.org/wiki/Set-theoretic_definition_of_natural_numbers 734 nums = [frozenset()] 735 for i in range(n-1): 736 num = frozenset(nums) 737 nums.append(num) 738 return nums[:n] 739 740 def powerset(s): 741 for i in range(len(s)+1): 742 yield from map(frozenset, itertools.combinations(s, i)) 743 744 for n in range(18): 745 t = 2 ** n 746 mask = t - 1 747 for nums in (range, zf_range): 748 u = len({h & mask for h in map(hash, powerset(nums(n)))}) 749 self.assertGreater(4*u, t) 750 751class FrozenSetSubclass(frozenset): 752 pass 753 754class TestFrozenSetSubclass(TestFrozenSet): 755 thetype = FrozenSetSubclass 756 basetype = frozenset 757 758 def test_constructor_identity(self): 759 s = self.thetype(range(3)) 760 t = self.thetype(s) 761 self.assertNotEqual(id(s), id(t)) 762 763 def test_copy(self): 764 dup = self.s.copy() 765 self.assertNotEqual(id(self.s), id(dup)) 766 767 def test_nested_empty_constructor(self): 768 s = self.thetype() 769 t = self.thetype(s) 770 self.assertEqual(s, t) 771 772 def test_singleton_empty_frozenset(self): 773 Frozenset = self.thetype 774 f = frozenset() 775 F = Frozenset() 776 efs = [Frozenset(), Frozenset([]), Frozenset(()), Frozenset(''), 777 Frozenset(), Frozenset([]), Frozenset(()), Frozenset(''), 778 Frozenset(range(0)), Frozenset(Frozenset()), 779 Frozenset(frozenset()), f, F, Frozenset(f), Frozenset(F)] 780 # All empty frozenset subclass instances should have different ids 781 self.assertEqual(len(set(map(id, efs))), len(efs)) 782 783# Tests taken from test_sets.py ============================================= 784 785empty_set = set() 786 787#============================================================================== 788 789class TestBasicOps: 790 791 def test_repr(self): 792 if self.repr is not None: 793 self.assertEqual(repr(self.set), self.repr) 794 795 def check_repr_against_values(self): 796 text = repr(self.set) 797 self.assertTrue(text.startswith('{')) 798 self.assertTrue(text.endswith('}')) 799 800 result = text[1:-1].split(', ') 801 result.sort() 802 sorted_repr_values = [repr(value) for value in self.values] 803 sorted_repr_values.sort() 804 self.assertEqual(result, sorted_repr_values) 805 806 def test_print(self): 807 try: 808 fo = open(support.TESTFN, "w") 809 fo.write(str(self.set)) 810 fo.close() 811 fo = open(support.TESTFN, "r") 812 self.assertEqual(fo.read(), repr(self.set)) 813 finally: 814 fo.close() 815 support.unlink(support.TESTFN) 816 817 def test_length(self): 818 self.assertEqual(len(self.set), self.length) 819 820 def test_self_equality(self): 821 self.assertEqual(self.set, self.set) 822 823 def test_equivalent_equality(self): 824 self.assertEqual(self.set, self.dup) 825 826 def test_copy(self): 827 self.assertEqual(self.set.copy(), self.dup) 828 829 def test_self_union(self): 830 result = self.set | self.set 831 self.assertEqual(result, self.dup) 832 833 def test_empty_union(self): 834 result = self.set | empty_set 835 self.assertEqual(result, self.dup) 836 837 def test_union_empty(self): 838 result = empty_set | self.set 839 self.assertEqual(result, self.dup) 840 841 def test_self_intersection(self): 842 result = self.set & self.set 843 self.assertEqual(result, self.dup) 844 845 def test_empty_intersection(self): 846 result = self.set & empty_set 847 self.assertEqual(result, empty_set) 848 849 def test_intersection_empty(self): 850 result = empty_set & self.set 851 self.assertEqual(result, empty_set) 852 853 def test_self_isdisjoint(self): 854 result = self.set.isdisjoint(self.set) 855 self.assertEqual(result, not self.set) 856 857 def test_empty_isdisjoint(self): 858 result = self.set.isdisjoint(empty_set) 859 self.assertEqual(result, True) 860 861 def test_isdisjoint_empty(self): 862 result = empty_set.isdisjoint(self.set) 863 self.assertEqual(result, True) 864 865 def test_self_symmetric_difference(self): 866 result = self.set ^ self.set 867 self.assertEqual(result, empty_set) 868 869 def test_empty_symmetric_difference(self): 870 result = self.set ^ empty_set 871 self.assertEqual(result, self.set) 872 873 def test_self_difference(self): 874 result = self.set - self.set 875 self.assertEqual(result, empty_set) 876 877 def test_empty_difference(self): 878 result = self.set - empty_set 879 self.assertEqual(result, self.dup) 880 881 def test_empty_difference_rev(self): 882 result = empty_set - self.set 883 self.assertEqual(result, empty_set) 884 885 def test_iteration(self): 886 for v in self.set: 887 self.assertIn(v, self.values) 888 setiter = iter(self.set) 889 self.assertEqual(setiter.__length_hint__(), len(self.set)) 890 891 def test_pickling(self): 892 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 893 p = pickle.dumps(self.set, proto) 894 copy = pickle.loads(p) 895 self.assertEqual(self.set, copy, 896 "%s != %s" % (self.set, copy)) 897 898 def test_issue_37219(self): 899 with self.assertRaises(TypeError): 900 set().difference(123) 901 with self.assertRaises(TypeError): 902 set().difference_update(123) 903 904#------------------------------------------------------------------------------ 905 906class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase): 907 def setUp(self): 908 self.case = "empty set" 909 self.values = [] 910 self.set = set(self.values) 911 self.dup = set(self.values) 912 self.length = 0 913 self.repr = "set()" 914 915#------------------------------------------------------------------------------ 916 917class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase): 918 def setUp(self): 919 self.case = "unit set (number)" 920 self.values = [3] 921 self.set = set(self.values) 922 self.dup = set(self.values) 923 self.length = 1 924 self.repr = "{3}" 925 926 def test_in(self): 927 self.assertIn(3, self.set) 928 929 def test_not_in(self): 930 self.assertNotIn(2, self.set) 931 932#------------------------------------------------------------------------------ 933 934class TestBasicOpsTuple(TestBasicOps, unittest.TestCase): 935 def setUp(self): 936 self.case = "unit set (tuple)" 937 self.values = [(0, "zero")] 938 self.set = set(self.values) 939 self.dup = set(self.values) 940 self.length = 1 941 self.repr = "{(0, 'zero')}" 942 943 def test_in(self): 944 self.assertIn((0, "zero"), self.set) 945 946 def test_not_in(self): 947 self.assertNotIn(9, self.set) 948 949#------------------------------------------------------------------------------ 950 951class TestBasicOpsTriple(TestBasicOps, unittest.TestCase): 952 def setUp(self): 953 self.case = "triple set" 954 self.values = [0, "zero", operator.add] 955 self.set = set(self.values) 956 self.dup = set(self.values) 957 self.length = 3 958 self.repr = None 959 960#------------------------------------------------------------------------------ 961 962class TestBasicOpsString(TestBasicOps, unittest.TestCase): 963 def setUp(self): 964 self.case = "string set" 965 self.values = ["a", "b", "c"] 966 self.set = set(self.values) 967 self.dup = set(self.values) 968 self.length = 3 969 970 def test_repr(self): 971 self.check_repr_against_values() 972 973#------------------------------------------------------------------------------ 974 975class TestBasicOpsBytes(TestBasicOps, unittest.TestCase): 976 def setUp(self): 977 self.case = "bytes set" 978 self.values = [b"a", b"b", b"c"] 979 self.set = set(self.values) 980 self.dup = set(self.values) 981 self.length = 3 982 983 def test_repr(self): 984 self.check_repr_against_values() 985 986#------------------------------------------------------------------------------ 987 988class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase): 989 def setUp(self): 990 self._warning_filters = support.check_warnings() 991 self._warning_filters.__enter__() 992 warnings.simplefilter('ignore', BytesWarning) 993 self.case = "string and bytes set" 994 self.values = ["a", "b", b"a", b"b"] 995 self.set = set(self.values) 996 self.dup = set(self.values) 997 self.length = 4 998 999 def tearDown(self): 1000 self._warning_filters.__exit__(None, None, None) 1001 1002 def test_repr(self): 1003 self.check_repr_against_values() 1004 1005#============================================================================== 1006 1007def baditer(): 1008 raise TypeError 1009 yield True 1010 1011def gooditer(): 1012 yield True 1013 1014class TestExceptionPropagation(unittest.TestCase): 1015 """SF 628246: Set constructor should not trap iterator TypeErrors""" 1016 1017 def test_instanceWithException(self): 1018 self.assertRaises(TypeError, set, baditer()) 1019 1020 def test_instancesWithoutException(self): 1021 # All of these iterables should load without exception. 1022 set([1,2,3]) 1023 set((1,2,3)) 1024 set({'one':1, 'two':2, 'three':3}) 1025 set(range(3)) 1026 set('abc') 1027 set(gooditer()) 1028 1029 def test_changingSizeWhileIterating(self): 1030 s = set([1,2,3]) 1031 try: 1032 for i in s: 1033 s.update([4]) 1034 except RuntimeError: 1035 pass 1036 else: 1037 self.fail("no exception when changing size during iteration") 1038 1039#============================================================================== 1040 1041class TestSetOfSets(unittest.TestCase): 1042 def test_constructor(self): 1043 inner = frozenset([1]) 1044 outer = set([inner]) 1045 element = outer.pop() 1046 self.assertEqual(type(element), frozenset) 1047 outer.add(inner) # Rebuild set of sets with .add method 1048 outer.remove(inner) 1049 self.assertEqual(outer, set()) # Verify that remove worked 1050 outer.discard(inner) # Absence of KeyError indicates working fine 1051 1052#============================================================================== 1053 1054class TestBinaryOps(unittest.TestCase): 1055 def setUp(self): 1056 self.set = set((2, 4, 6)) 1057 1058 def test_eq(self): # SF bug 643115 1059 self.assertEqual(self.set, set({2:1,4:3,6:5})) 1060 1061 def test_union_subset(self): 1062 result = self.set | set([2]) 1063 self.assertEqual(result, set((2, 4, 6))) 1064 1065 def test_union_superset(self): 1066 result = self.set | set([2, 4, 6, 8]) 1067 self.assertEqual(result, set([2, 4, 6, 8])) 1068 1069 def test_union_overlap(self): 1070 result = self.set | set([3, 4, 5]) 1071 self.assertEqual(result, set([2, 3, 4, 5, 6])) 1072 1073 def test_union_non_overlap(self): 1074 result = self.set | set([8]) 1075 self.assertEqual(result, set([2, 4, 6, 8])) 1076 1077 def test_intersection_subset(self): 1078 result = self.set & set((2, 4)) 1079 self.assertEqual(result, set((2, 4))) 1080 1081 def test_intersection_superset(self): 1082 result = self.set & set([2, 4, 6, 8]) 1083 self.assertEqual(result, set([2, 4, 6])) 1084 1085 def test_intersection_overlap(self): 1086 result = self.set & set([3, 4, 5]) 1087 self.assertEqual(result, set([4])) 1088 1089 def test_intersection_non_overlap(self): 1090 result = self.set & set([8]) 1091 self.assertEqual(result, empty_set) 1092 1093 def test_isdisjoint_subset(self): 1094 result = self.set.isdisjoint(set((2, 4))) 1095 self.assertEqual(result, False) 1096 1097 def test_isdisjoint_superset(self): 1098 result = self.set.isdisjoint(set([2, 4, 6, 8])) 1099 self.assertEqual(result, False) 1100 1101 def test_isdisjoint_overlap(self): 1102 result = self.set.isdisjoint(set([3, 4, 5])) 1103 self.assertEqual(result, False) 1104 1105 def test_isdisjoint_non_overlap(self): 1106 result = self.set.isdisjoint(set([8])) 1107 self.assertEqual(result, True) 1108 1109 def test_sym_difference_subset(self): 1110 result = self.set ^ set((2, 4)) 1111 self.assertEqual(result, set([6])) 1112 1113 def test_sym_difference_superset(self): 1114 result = self.set ^ set((2, 4, 6, 8)) 1115 self.assertEqual(result, set([8])) 1116 1117 def test_sym_difference_overlap(self): 1118 result = self.set ^ set((3, 4, 5)) 1119 self.assertEqual(result, set([2, 3, 5, 6])) 1120 1121 def test_sym_difference_non_overlap(self): 1122 result = self.set ^ set([8]) 1123 self.assertEqual(result, set([2, 4, 6, 8])) 1124 1125#============================================================================== 1126 1127class TestUpdateOps(unittest.TestCase): 1128 def setUp(self): 1129 self.set = set((2, 4, 6)) 1130 1131 def test_union_subset(self): 1132 self.set |= set([2]) 1133 self.assertEqual(self.set, set((2, 4, 6))) 1134 1135 def test_union_superset(self): 1136 self.set |= set([2, 4, 6, 8]) 1137 self.assertEqual(self.set, set([2, 4, 6, 8])) 1138 1139 def test_union_overlap(self): 1140 self.set |= set([3, 4, 5]) 1141 self.assertEqual(self.set, set([2, 3, 4, 5, 6])) 1142 1143 def test_union_non_overlap(self): 1144 self.set |= set([8]) 1145 self.assertEqual(self.set, set([2, 4, 6, 8])) 1146 1147 def test_union_method_call(self): 1148 self.set.update(set([3, 4, 5])) 1149 self.assertEqual(self.set, set([2, 3, 4, 5, 6])) 1150 1151 def test_intersection_subset(self): 1152 self.set &= set((2, 4)) 1153 self.assertEqual(self.set, set((2, 4))) 1154 1155 def test_intersection_superset(self): 1156 self.set &= set([2, 4, 6, 8]) 1157 self.assertEqual(self.set, set([2, 4, 6])) 1158 1159 def test_intersection_overlap(self): 1160 self.set &= set([3, 4, 5]) 1161 self.assertEqual(self.set, set([4])) 1162 1163 def test_intersection_non_overlap(self): 1164 self.set &= set([8]) 1165 self.assertEqual(self.set, empty_set) 1166 1167 def test_intersection_method_call(self): 1168 self.set.intersection_update(set([3, 4, 5])) 1169 self.assertEqual(self.set, set([4])) 1170 1171 def test_sym_difference_subset(self): 1172 self.set ^= set((2, 4)) 1173 self.assertEqual(self.set, set([6])) 1174 1175 def test_sym_difference_superset(self): 1176 self.set ^= set((2, 4, 6, 8)) 1177 self.assertEqual(self.set, set([8])) 1178 1179 def test_sym_difference_overlap(self): 1180 self.set ^= set((3, 4, 5)) 1181 self.assertEqual(self.set, set([2, 3, 5, 6])) 1182 1183 def test_sym_difference_non_overlap(self): 1184 self.set ^= set([8]) 1185 self.assertEqual(self.set, set([2, 4, 6, 8])) 1186 1187 def test_sym_difference_method_call(self): 1188 self.set.symmetric_difference_update(set([3, 4, 5])) 1189 self.assertEqual(self.set, set([2, 3, 5, 6])) 1190 1191 def test_difference_subset(self): 1192 self.set -= set((2, 4)) 1193 self.assertEqual(self.set, set([6])) 1194 1195 def test_difference_superset(self): 1196 self.set -= set((2, 4, 6, 8)) 1197 self.assertEqual(self.set, set([])) 1198 1199 def test_difference_overlap(self): 1200 self.set -= set((3, 4, 5)) 1201 self.assertEqual(self.set, set([2, 6])) 1202 1203 def test_difference_non_overlap(self): 1204 self.set -= set([8]) 1205 self.assertEqual(self.set, set([2, 4, 6])) 1206 1207 def test_difference_method_call(self): 1208 self.set.difference_update(set([3, 4, 5])) 1209 self.assertEqual(self.set, set([2, 6])) 1210 1211#============================================================================== 1212 1213class TestMutate(unittest.TestCase): 1214 def setUp(self): 1215 self.values = ["a", "b", "c"] 1216 self.set = set(self.values) 1217 1218 def test_add_present(self): 1219 self.set.add("c") 1220 self.assertEqual(self.set, set("abc")) 1221 1222 def test_add_absent(self): 1223 self.set.add("d") 1224 self.assertEqual(self.set, set("abcd")) 1225 1226 def test_add_until_full(self): 1227 tmp = set() 1228 expected_len = 0 1229 for v in self.values: 1230 tmp.add(v) 1231 expected_len += 1 1232 self.assertEqual(len(tmp), expected_len) 1233 self.assertEqual(tmp, self.set) 1234 1235 def test_remove_present(self): 1236 self.set.remove("b") 1237 self.assertEqual(self.set, set("ac")) 1238 1239 def test_remove_absent(self): 1240 try: 1241 self.set.remove("d") 1242 self.fail("Removing missing element should have raised LookupError") 1243 except LookupError: 1244 pass 1245 1246 def test_remove_until_empty(self): 1247 expected_len = len(self.set) 1248 for v in self.values: 1249 self.set.remove(v) 1250 expected_len -= 1 1251 self.assertEqual(len(self.set), expected_len) 1252 1253 def test_discard_present(self): 1254 self.set.discard("c") 1255 self.assertEqual(self.set, set("ab")) 1256 1257 def test_discard_absent(self): 1258 self.set.discard("d") 1259 self.assertEqual(self.set, set("abc")) 1260 1261 def test_clear(self): 1262 self.set.clear() 1263 self.assertEqual(len(self.set), 0) 1264 1265 def test_pop(self): 1266 popped = {} 1267 while self.set: 1268 popped[self.set.pop()] = None 1269 self.assertEqual(len(popped), len(self.values)) 1270 for v in self.values: 1271 self.assertIn(v, popped) 1272 1273 def test_update_empty_tuple(self): 1274 self.set.update(()) 1275 self.assertEqual(self.set, set(self.values)) 1276 1277 def test_update_unit_tuple_overlap(self): 1278 self.set.update(("a",)) 1279 self.assertEqual(self.set, set(self.values)) 1280 1281 def test_update_unit_tuple_non_overlap(self): 1282 self.set.update(("a", "z")) 1283 self.assertEqual(self.set, set(self.values + ["z"])) 1284 1285#============================================================================== 1286 1287class TestSubsets: 1288 1289 case2method = {"<=": "issubset", 1290 ">=": "issuperset", 1291 } 1292 1293 reverse = {"==": "==", 1294 "!=": "!=", 1295 "<": ">", 1296 ">": "<", 1297 "<=": ">=", 1298 ">=": "<=", 1299 } 1300 1301 def test_issubset(self): 1302 x = self.left 1303 y = self.right 1304 for case in "!=", "==", "<", "<=", ">", ">=": 1305 expected = case in self.cases 1306 # Test the binary infix spelling. 1307 result = eval("x" + case + "y", locals()) 1308 self.assertEqual(result, expected) 1309 # Test the "friendly" method-name spelling, if one exists. 1310 if case in TestSubsets.case2method: 1311 method = getattr(x, TestSubsets.case2method[case]) 1312 result = method(y) 1313 self.assertEqual(result, expected) 1314 1315 # Now do the same for the operands reversed. 1316 rcase = TestSubsets.reverse[case] 1317 result = eval("y" + rcase + "x", locals()) 1318 self.assertEqual(result, expected) 1319 if rcase in TestSubsets.case2method: 1320 method = getattr(y, TestSubsets.case2method[rcase]) 1321 result = method(x) 1322 self.assertEqual(result, expected) 1323#------------------------------------------------------------------------------ 1324 1325class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase): 1326 left = set() 1327 right = set() 1328 name = "both empty" 1329 cases = "==", "<=", ">=" 1330 1331#------------------------------------------------------------------------------ 1332 1333class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase): 1334 left = set([1, 2]) 1335 right = set([1, 2]) 1336 name = "equal pair" 1337 cases = "==", "<=", ">=" 1338 1339#------------------------------------------------------------------------------ 1340 1341class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase): 1342 left = set() 1343 right = set([1, 2]) 1344 name = "one empty, one non-empty" 1345 cases = "!=", "<", "<=" 1346 1347#------------------------------------------------------------------------------ 1348 1349class TestSubsetPartial(TestSubsets, unittest.TestCase): 1350 left = set([1]) 1351 right = set([1, 2]) 1352 name = "one a non-empty proper subset of other" 1353 cases = "!=", "<", "<=" 1354 1355#------------------------------------------------------------------------------ 1356 1357class TestSubsetNonOverlap(TestSubsets, unittest.TestCase): 1358 left = set([1]) 1359 right = set([2]) 1360 name = "neither empty, neither contains" 1361 cases = "!=" 1362 1363#============================================================================== 1364 1365class TestOnlySetsInBinaryOps: 1366 1367 def test_eq_ne(self): 1368 # Unlike the others, this is testing that == and != *are* allowed. 1369 self.assertEqual(self.other == self.set, False) 1370 self.assertEqual(self.set == self.other, False) 1371 self.assertEqual(self.other != self.set, True) 1372 self.assertEqual(self.set != self.other, True) 1373 1374 def test_ge_gt_le_lt(self): 1375 self.assertRaises(TypeError, lambda: self.set < self.other) 1376 self.assertRaises(TypeError, lambda: self.set <= self.other) 1377 self.assertRaises(TypeError, lambda: self.set > self.other) 1378 self.assertRaises(TypeError, lambda: self.set >= self.other) 1379 1380 self.assertRaises(TypeError, lambda: self.other < self.set) 1381 self.assertRaises(TypeError, lambda: self.other <= self.set) 1382 self.assertRaises(TypeError, lambda: self.other > self.set) 1383 self.assertRaises(TypeError, lambda: self.other >= self.set) 1384 1385 def test_update_operator(self): 1386 try: 1387 self.set |= self.other 1388 except TypeError: 1389 pass 1390 else: 1391 self.fail("expected TypeError") 1392 1393 def test_update(self): 1394 if self.otherIsIterable: 1395 self.set.update(self.other) 1396 else: 1397 self.assertRaises(TypeError, self.set.update, self.other) 1398 1399 def test_union(self): 1400 self.assertRaises(TypeError, lambda: self.set | self.other) 1401 self.assertRaises(TypeError, lambda: self.other | self.set) 1402 if self.otherIsIterable: 1403 self.set.union(self.other) 1404 else: 1405 self.assertRaises(TypeError, self.set.union, self.other) 1406 1407 def test_intersection_update_operator(self): 1408 try: 1409 self.set &= self.other 1410 except TypeError: 1411 pass 1412 else: 1413 self.fail("expected TypeError") 1414 1415 def test_intersection_update(self): 1416 if self.otherIsIterable: 1417 self.set.intersection_update(self.other) 1418 else: 1419 self.assertRaises(TypeError, 1420 self.set.intersection_update, 1421 self.other) 1422 1423 def test_intersection(self): 1424 self.assertRaises(TypeError, lambda: self.set & self.other) 1425 self.assertRaises(TypeError, lambda: self.other & self.set) 1426 if self.otherIsIterable: 1427 self.set.intersection(self.other) 1428 else: 1429 self.assertRaises(TypeError, self.set.intersection, self.other) 1430 1431 def test_sym_difference_update_operator(self): 1432 try: 1433 self.set ^= self.other 1434 except TypeError: 1435 pass 1436 else: 1437 self.fail("expected TypeError") 1438 1439 def test_sym_difference_update(self): 1440 if self.otherIsIterable: 1441 self.set.symmetric_difference_update(self.other) 1442 else: 1443 self.assertRaises(TypeError, 1444 self.set.symmetric_difference_update, 1445 self.other) 1446 1447 def test_sym_difference(self): 1448 self.assertRaises(TypeError, lambda: self.set ^ self.other) 1449 self.assertRaises(TypeError, lambda: self.other ^ self.set) 1450 if self.otherIsIterable: 1451 self.set.symmetric_difference(self.other) 1452 else: 1453 self.assertRaises(TypeError, self.set.symmetric_difference, self.other) 1454 1455 def test_difference_update_operator(self): 1456 try: 1457 self.set -= self.other 1458 except TypeError: 1459 pass 1460 else: 1461 self.fail("expected TypeError") 1462 1463 def test_difference_update(self): 1464 if self.otherIsIterable: 1465 self.set.difference_update(self.other) 1466 else: 1467 self.assertRaises(TypeError, 1468 self.set.difference_update, 1469 self.other) 1470 1471 def test_difference(self): 1472 self.assertRaises(TypeError, lambda: self.set - self.other) 1473 self.assertRaises(TypeError, lambda: self.other - self.set) 1474 if self.otherIsIterable: 1475 self.set.difference(self.other) 1476 else: 1477 self.assertRaises(TypeError, self.set.difference, self.other) 1478 1479#------------------------------------------------------------------------------ 1480 1481class TestOnlySetsNumeric(TestOnlySetsInBinaryOps, unittest.TestCase): 1482 def setUp(self): 1483 self.set = set((1, 2, 3)) 1484 self.other = 19 1485 self.otherIsIterable = False 1486 1487#------------------------------------------------------------------------------ 1488 1489class TestOnlySetsDict(TestOnlySetsInBinaryOps, unittest.TestCase): 1490 def setUp(self): 1491 self.set = set((1, 2, 3)) 1492 self.other = {1:2, 3:4} 1493 self.otherIsIterable = True 1494 1495#------------------------------------------------------------------------------ 1496 1497class TestOnlySetsOperator(TestOnlySetsInBinaryOps, unittest.TestCase): 1498 def setUp(self): 1499 self.set = set((1, 2, 3)) 1500 self.other = operator.add 1501 self.otherIsIterable = False 1502 1503#------------------------------------------------------------------------------ 1504 1505class TestOnlySetsTuple(TestOnlySetsInBinaryOps, unittest.TestCase): 1506 def setUp(self): 1507 self.set = set((1, 2, 3)) 1508 self.other = (2, 4, 6) 1509 self.otherIsIterable = True 1510 1511#------------------------------------------------------------------------------ 1512 1513class TestOnlySetsString(TestOnlySetsInBinaryOps, unittest.TestCase): 1514 def setUp(self): 1515 self.set = set((1, 2, 3)) 1516 self.other = 'abc' 1517 self.otherIsIterable = True 1518 1519#------------------------------------------------------------------------------ 1520 1521class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase): 1522 def setUp(self): 1523 def gen(): 1524 for i in range(0, 10, 2): 1525 yield i 1526 self.set = set((1, 2, 3)) 1527 self.other = gen() 1528 self.otherIsIterable = True 1529 1530#============================================================================== 1531 1532class TestCopying: 1533 1534 def test_copy(self): 1535 dup = self.set.copy() 1536 dup_list = sorted(dup, key=repr) 1537 set_list = sorted(self.set, key=repr) 1538 self.assertEqual(len(dup_list), len(set_list)) 1539 for i in range(len(dup_list)): 1540 self.assertTrue(dup_list[i] is set_list[i]) 1541 1542 def test_deep_copy(self): 1543 dup = copy.deepcopy(self.set) 1544 ##print type(dup), repr(dup) 1545 dup_list = sorted(dup, key=repr) 1546 set_list = sorted(self.set, key=repr) 1547 self.assertEqual(len(dup_list), len(set_list)) 1548 for i in range(len(dup_list)): 1549 self.assertEqual(dup_list[i], set_list[i]) 1550 1551#------------------------------------------------------------------------------ 1552 1553class TestCopyingEmpty(TestCopying, unittest.TestCase): 1554 def setUp(self): 1555 self.set = set() 1556 1557#------------------------------------------------------------------------------ 1558 1559class TestCopyingSingleton(TestCopying, unittest.TestCase): 1560 def setUp(self): 1561 self.set = set(["hello"]) 1562 1563#------------------------------------------------------------------------------ 1564 1565class TestCopyingTriple(TestCopying, unittest.TestCase): 1566 def setUp(self): 1567 self.set = set(["zero", 0, None]) 1568 1569#------------------------------------------------------------------------------ 1570 1571class TestCopyingTuple(TestCopying, unittest.TestCase): 1572 def setUp(self): 1573 self.set = set([(1, 2)]) 1574 1575#------------------------------------------------------------------------------ 1576 1577class TestCopyingNested(TestCopying, unittest.TestCase): 1578 def setUp(self): 1579 self.set = set([((1, 2), (3, 4))]) 1580 1581#============================================================================== 1582 1583class TestIdentities(unittest.TestCase): 1584 def setUp(self): 1585 self.a = set('abracadabra') 1586 self.b = set('alacazam') 1587 1588 def test_binopsVsSubsets(self): 1589 a, b = self.a, self.b 1590 self.assertTrue(a - b < a) 1591 self.assertTrue(b - a < b) 1592 self.assertTrue(a & b < a) 1593 self.assertTrue(a & b < b) 1594 self.assertTrue(a | b > a) 1595 self.assertTrue(a | b > b) 1596 self.assertTrue(a ^ b < a | b) 1597 1598 def test_commutativity(self): 1599 a, b = self.a, self.b 1600 self.assertEqual(a&b, b&a) 1601 self.assertEqual(a|b, b|a) 1602 self.assertEqual(a^b, b^a) 1603 if a != b: 1604 self.assertNotEqual(a-b, b-a) 1605 1606 def test_summations(self): 1607 # check that sums of parts equal the whole 1608 a, b = self.a, self.b 1609 self.assertEqual((a-b)|(a&b)|(b-a), a|b) 1610 self.assertEqual((a&b)|(a^b), a|b) 1611 self.assertEqual(a|(b-a), a|b) 1612 self.assertEqual((a-b)|b, a|b) 1613 self.assertEqual((a-b)|(a&b), a) 1614 self.assertEqual((b-a)|(a&b), b) 1615 self.assertEqual((a-b)|(b-a), a^b) 1616 1617 def test_exclusion(self): 1618 # check that inverse operations show non-overlap 1619 a, b, zero = self.a, self.b, set() 1620 self.assertEqual((a-b)&b, zero) 1621 self.assertEqual((b-a)&a, zero) 1622 self.assertEqual((a&b)&(a^b), zero) 1623 1624# Tests derived from test_itertools.py ======================================= 1625 1626def R(seqn): 1627 'Regular generator' 1628 for i in seqn: 1629 yield i 1630 1631class G: 1632 'Sequence using __getitem__' 1633 def __init__(self, seqn): 1634 self.seqn = seqn 1635 def __getitem__(self, i): 1636 return self.seqn[i] 1637 1638class I: 1639 'Sequence using iterator protocol' 1640 def __init__(self, seqn): 1641 self.seqn = seqn 1642 self.i = 0 1643 def __iter__(self): 1644 return self 1645 def __next__(self): 1646 if self.i >= len(self.seqn): raise StopIteration 1647 v = self.seqn[self.i] 1648 self.i += 1 1649 return v 1650 1651class Ig: 1652 'Sequence using iterator protocol defined with a generator' 1653 def __init__(self, seqn): 1654 self.seqn = seqn 1655 self.i = 0 1656 def __iter__(self): 1657 for val in self.seqn: 1658 yield val 1659 1660class X: 1661 'Missing __getitem__ and __iter__' 1662 def __init__(self, seqn): 1663 self.seqn = seqn 1664 self.i = 0 1665 def __next__(self): 1666 if self.i >= len(self.seqn): raise StopIteration 1667 v = self.seqn[self.i] 1668 self.i += 1 1669 return v 1670 1671class N: 1672 'Iterator missing __next__()' 1673 def __init__(self, seqn): 1674 self.seqn = seqn 1675 self.i = 0 1676 def __iter__(self): 1677 return self 1678 1679class E: 1680 'Test propagation of exceptions' 1681 def __init__(self, seqn): 1682 self.seqn = seqn 1683 self.i = 0 1684 def __iter__(self): 1685 return self 1686 def __next__(self): 1687 3 // 0 1688 1689class S: 1690 'Test immediate stop' 1691 def __init__(self, seqn): 1692 pass 1693 def __iter__(self): 1694 return self 1695 def __next__(self): 1696 raise StopIteration 1697 1698from itertools import chain 1699def L(seqn): 1700 'Test multiple tiers of iterators' 1701 return chain(map(lambda x:x, R(Ig(G(seqn))))) 1702 1703class TestVariousIteratorArgs(unittest.TestCase): 1704 1705 def test_constructor(self): 1706 for cons in (set, frozenset): 1707 for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)): 1708 for g in (G, I, Ig, S, L, R): 1709 self.assertEqual(sorted(cons(g(s)), key=repr), sorted(g(s), key=repr)) 1710 self.assertRaises(TypeError, cons , X(s)) 1711 self.assertRaises(TypeError, cons , N(s)) 1712 self.assertRaises(ZeroDivisionError, cons , E(s)) 1713 1714 def test_inline_methods(self): 1715 s = set('november') 1716 for data in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5), 'december'): 1717 for meth in (s.union, s.intersection, s.difference, s.symmetric_difference, s.isdisjoint): 1718 for g in (G, I, Ig, L, R): 1719 expected = meth(data) 1720 actual = meth(g(data)) 1721 if isinstance(expected, bool): 1722 self.assertEqual(actual, expected) 1723 else: 1724 self.assertEqual(sorted(actual, key=repr), sorted(expected, key=repr)) 1725 self.assertRaises(TypeError, meth, X(s)) 1726 self.assertRaises(TypeError, meth, N(s)) 1727 self.assertRaises(ZeroDivisionError, meth, E(s)) 1728 1729 def test_inplace_methods(self): 1730 for data in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5), 'december'): 1731 for methname in ('update', 'intersection_update', 1732 'difference_update', 'symmetric_difference_update'): 1733 for g in (G, I, Ig, S, L, R): 1734 s = set('january') 1735 t = s.copy() 1736 getattr(s, methname)(list(g(data))) 1737 getattr(t, methname)(g(data)) 1738 self.assertEqual(sorted(s, key=repr), sorted(t, key=repr)) 1739 1740 self.assertRaises(TypeError, getattr(set('january'), methname), X(data)) 1741 self.assertRaises(TypeError, getattr(set('january'), methname), N(data)) 1742 self.assertRaises(ZeroDivisionError, getattr(set('january'), methname), E(data)) 1743 1744class bad_eq: 1745 def __eq__(self, other): 1746 if be_bad: 1747 set2.clear() 1748 raise ZeroDivisionError 1749 return self is other 1750 def __hash__(self): 1751 return 0 1752 1753class bad_dict_clear: 1754 def __eq__(self, other): 1755 if be_bad: 1756 dict2.clear() 1757 return self is other 1758 def __hash__(self): 1759 return 0 1760 1761class TestWeirdBugs(unittest.TestCase): 1762 def test_8420_set_merge(self): 1763 # This used to segfault 1764 global be_bad, set2, dict2 1765 be_bad = False 1766 set1 = {bad_eq()} 1767 set2 = {bad_eq() for i in range(75)} 1768 be_bad = True 1769 self.assertRaises(ZeroDivisionError, set1.update, set2) 1770 1771 be_bad = False 1772 set1 = {bad_dict_clear()} 1773 dict2 = {bad_dict_clear(): None} 1774 be_bad = True 1775 set1.symmetric_difference_update(dict2) 1776 1777 def test_iter_and_mutate(self): 1778 # Issue #24581 1779 s = set(range(100)) 1780 s.clear() 1781 s.update(range(100)) 1782 si = iter(s) 1783 s.clear() 1784 a = list(range(100)) 1785 s.update(range(100)) 1786 list(si) 1787 1788 def test_merge_and_mutate(self): 1789 class X: 1790 def __hash__(self): 1791 return hash(0) 1792 def __eq__(self, o): 1793 other.clear() 1794 return False 1795 1796 other = set() 1797 other = {X() for i in range(10)} 1798 s = {0} 1799 s.update(other) 1800 1801# Application tests (based on David Eppstein's graph recipes ==================================== 1802 1803def powerset(U): 1804 """Generates all subsets of a set or sequence U.""" 1805 U = iter(U) 1806 try: 1807 x = frozenset([next(U)]) 1808 for S in powerset(U): 1809 yield S 1810 yield S | x 1811 except StopIteration: 1812 yield frozenset() 1813 1814def cube(n): 1815 """Graph of n-dimensional hypercube.""" 1816 singletons = [frozenset([x]) for x in range(n)] 1817 return dict([(x, frozenset([x^s for s in singletons])) 1818 for x in powerset(range(n))]) 1819 1820def linegraph(G): 1821 """Graph, the vertices of which are edges of G, 1822 with two vertices being adjacent iff the corresponding 1823 edges share a vertex.""" 1824 L = {} 1825 for x in G: 1826 for y in G[x]: 1827 nx = [frozenset([x,z]) for z in G[x] if z != y] 1828 ny = [frozenset([y,z]) for z in G[y] if z != x] 1829 L[frozenset([x,y])] = frozenset(nx+ny) 1830 return L 1831 1832def faces(G): 1833 'Return a set of faces in G. Where a face is a set of vertices on that face' 1834 # currently limited to triangles,squares, and pentagons 1835 f = set() 1836 for v1, edges in G.items(): 1837 for v2 in edges: 1838 for v3 in G[v2]: 1839 if v1 == v3: 1840 continue 1841 if v1 in G[v3]: 1842 f.add(frozenset([v1, v2, v3])) 1843 else: 1844 for v4 in G[v3]: 1845 if v4 == v2: 1846 continue 1847 if v1 in G[v4]: 1848 f.add(frozenset([v1, v2, v3, v4])) 1849 else: 1850 for v5 in G[v4]: 1851 if v5 == v3 or v5 == v2: 1852 continue 1853 if v1 in G[v5]: 1854 f.add(frozenset([v1, v2, v3, v4, v5])) 1855 return f 1856 1857 1858class TestGraphs(unittest.TestCase): 1859 1860 def test_cube(self): 1861 1862 g = cube(3) # vert --> {v1, v2, v3} 1863 vertices1 = set(g) 1864 self.assertEqual(len(vertices1), 8) # eight vertices 1865 for edge in g.values(): 1866 self.assertEqual(len(edge), 3) # each vertex connects to three edges 1867 vertices2 = set(v for edges in g.values() for v in edges) 1868 self.assertEqual(vertices1, vertices2) # edge vertices in original set 1869 1870 cubefaces = faces(g) 1871 self.assertEqual(len(cubefaces), 6) # six faces 1872 for face in cubefaces: 1873 self.assertEqual(len(face), 4) # each face is a square 1874 1875 def test_cuboctahedron(self): 1876 1877 # http://en.wikipedia.org/wiki/Cuboctahedron 1878 # 8 triangular faces and 6 square faces 1879 # 12 identical vertices each connecting a triangle and square 1880 1881 g = cube(3) 1882 cuboctahedron = linegraph(g) # V( --> {V1, V2, V3, V4} 1883 self.assertEqual(len(cuboctahedron), 12)# twelve vertices 1884 1885 vertices = set(cuboctahedron) 1886 for edges in cuboctahedron.values(): 1887 self.assertEqual(len(edges), 4) # each vertex connects to four other vertices 1888 othervertices = set(edge for edges in cuboctahedron.values() for edge in edges) 1889 self.assertEqual(vertices, othervertices) # edge vertices in original set 1890 1891 cubofaces = faces(cuboctahedron) 1892 facesizes = collections.defaultdict(int) 1893 for face in cubofaces: 1894 facesizes[len(face)] += 1 1895 self.assertEqual(facesizes[3], 8) # eight triangular faces 1896 self.assertEqual(facesizes[4], 6) # six square faces 1897 1898 for vertex in cuboctahedron: 1899 edge = vertex # Cuboctahedron vertices are edges in Cube 1900 self.assertEqual(len(edge), 2) # Two cube vertices define an edge 1901 for cubevert in edge: 1902 self.assertIn(cubevert, g) 1903 1904 1905#============================================================================== 1906 1907if __name__ == "__main__": 1908 unittest.main() 1909