• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import unittest
2from test import support
3import gc
4import weakref
5import operator
6import copy
7import pickle
8from random import randrange, shuffle
9import warnings
10import collections
11import collections.abc
12import itertools
13
class PassThru(Exception):
    """Marker exception raised by check_pass_thru().

    Tests assert that this exact exception escapes set operations, proving
    that errors raised while consuming an input iterable are propagated.
    """
16
def check_pass_thru():
    """Generator that raises PassThru as soon as iteration begins.

    The unreachable ``yield`` is what makes this a generator function, so
    the exception fires on the first next() rather than at call time.
    """
    raise PassThru
    yield  # never reached; only turns the function into a generator
20
class BadCmp:
    """Hashable object whose equality test always blows up.

    Every instance hashes to 1, so placing two of them in a hash table forces
    an __eq__ call, which raises RuntimeError.
    """

    def __hash__(self):
        return 1

    def __eq__(self, other):
        raise RuntimeError
26
class ReprWrapper:
    """Object whose repr() is the repr() of its ``value`` attribute.

    Storing a container in ``value`` while the container also holds this
    wrapper produces a self-referential repr() call.
    """

    def __repr__(self):
        wrapped = self.value
        return repr(wrapped)
31
class HashCountingInt(int):
    """int subclass that records, in ``hash_count``, how often it was hashed."""

    def __init__(self, *args):
        self.hash_count = 0

    def __hash__(self):
        self.hash_count = self.hash_count + 1
        return super().__hash__()
39
class TestJointOps:
    """Tests common to both set and frozenset.

    This is a mixin: concrete test classes combine it with
    unittest.TestCase and define two class attributes -- ``thetype``, the
    type under test, and ``basetype``, the type that non-mutating
    operations such as union() and intersection() must return.
    """

    def setUp(self):
        self.word = word = 'simsalabim'
        self.otherword = 'madagascar'
        self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
        self.s = self.thetype(word)
        # Reference model: dict keys deduplicate exactly like set members.
        self.d = dict.fromkeys(word)

    def test_new_or_init(self):
        self.assertRaises(TypeError, self.thetype, [], 2)
        self.assertRaises(TypeError, set().__init__, a=1)

    def test_uniquification(self):
        actual = sorted(self.s)
        expected = sorted(self.d)
        self.assertEqual(actual, expected)
        # Errors raised while consuming the input must propagate unchanged.
        self.assertRaises(PassThru, self.thetype, check_pass_thru())
        # Unhashable elements are rejected.
        self.assertRaises(TypeError, self.thetype, [[]])

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        self.assertRaises(TypeError, self.s.__contains__, [[]])
        # A set-typed probe must find the equal frozenset member.
        s = self.thetype([frozenset(self.letters)])
        self.assertIn(self.thetype(self.letters), s)

    def test_union(self):
        u = self.s.union(self.otherword)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.otherword)
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(u), self.basetype)
        self.assertRaises(PassThru, self.s.union, check_pass_thru())
        self.assertRaises(TypeError, self.s.union, [[]])
        for C in set, frozenset, dict.fromkeys, str, list, tuple:
            self.assertEqual(self.thetype('abcba').union(C('cdc')), set('abcd'))
            self.assertEqual(self.thetype('abcba').union(C('efgfe')), set('abcefg'))
            self.assertEqual(self.thetype('abcba').union(C('ccb')), set('abc'))
            self.assertEqual(self.thetype('abcba').union(C('ef')), set('abcef'))
            self.assertEqual(self.thetype('abcba').union(C('ef'), C('fg')), set('abcefg'))

        # Issue #6573
        x = self.thetype()
        self.assertEqual(x.union(set([1]), x, set([2])), self.thetype([1, 2]))

    def test_or(self):
        i = self.s.union(self.otherword)
        self.assertEqual(self.s | set(self.otherword), i)
        self.assertEqual(self.s | frozenset(self.otherword), i)
        # The operator form, unlike union(), only accepts set operands.
        try:
            self.s | self.otherword
        except TypeError:
            pass
        else:
            self.fail("s|t did not screen-out general iterables")

    def test_intersection(self):
        i = self.s.intersection(self.otherword)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c in self.otherword)
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(i), self.basetype)
        self.assertRaises(PassThru, self.s.intersection, check_pass_thru())
        for C in set, frozenset, dict.fromkeys, str, list, tuple:
            self.assertEqual(self.thetype('abcba').intersection(C('cdc')), set('cc'))
            self.assertEqual(self.thetype('abcba').intersection(C('efgfe')), set(''))
            self.assertEqual(self.thetype('abcba').intersection(C('ccb')), set('bc'))
            self.assertEqual(self.thetype('abcba').intersection(C('ef')), set(''))
            self.assertEqual(self.thetype('abcba').intersection(C('cbcf'), C('bag')), set('b'))
        # With no arguments intersection() returns an equal but distinct
        # object for every tested type.  (The previous version guarded an
        # identity branch with `self.thetype == frozenset()`, which compares
        # a type with an instance and is always False, so that branch was
        # dead code and only this check ever ran.)
        s = self.thetype('abcba')
        z = s.intersection()
        self.assertEqual(z, s)
        self.assertIsNot(z, s)

    def test_isdisjoint(self):
        def f(s1, s2):
            'Pure python equivalent of isdisjoint()'
            return not set(s1).intersection(s2)
        for larg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
            s1 = self.thetype(larg)
            for rarg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
                for C in set, frozenset, dict.fromkeys, str, list, tuple:
                    s2 = C(rarg)
                    actual = s1.isdisjoint(s2)
                    expected = f(s1, s2)
                    self.assertEqual(actual, expected)
                    # Must be a real bool, not merely truthy/falsy.
                    self.assertIsInstance(actual, bool)

    def test_and(self):
        i = self.s.intersection(self.otherword)
        self.assertEqual(self.s & set(self.otherword), i)
        self.assertEqual(self.s & frozenset(self.otherword), i)
        try:
            self.s & self.otherword
        except TypeError:
            pass
        else:
            self.fail("s&t did not screen-out general iterables")

    def test_difference(self):
        i = self.s.difference(self.otherword)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.otherword)
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(i), self.basetype)
        self.assertRaises(PassThru, self.s.difference, check_pass_thru())
        self.assertRaises(TypeError, self.s.difference, [[]])
        # difference() with no arguments does not depend on C; check it once.
        self.assertEqual(self.thetype('abcba').difference(), set('abc'))
        for C in set, frozenset, dict.fromkeys, str, list, tuple:
            self.assertEqual(self.thetype('abcba').difference(C('cdc')), set('ab'))
            self.assertEqual(self.thetype('abcba').difference(C('efgfe')), set('abc'))
            self.assertEqual(self.thetype('abcba').difference(C('ccb')), set('a'))
            self.assertEqual(self.thetype('abcba').difference(C('ef')), set('abc'))
            self.assertEqual(self.thetype('abcba').difference(C('a'), C('b')), set('c'))

    def test_sub(self):
        i = self.s.difference(self.otherword)
        self.assertEqual(self.s - set(self.otherword), i)
        self.assertEqual(self.s - frozenset(self.otherword), i)
        try:
            self.s - self.otherword
        except TypeError:
            pass
        else:
            self.fail("s-t did not screen-out general iterables")

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.otherword)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.otherword))
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(i), self.basetype)
        self.assertRaises(PassThru, self.s.symmetric_difference, check_pass_thru())
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        for C in set, frozenset, dict.fromkeys, str, list, tuple:
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('cdc')), set('abd'))
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('efgfe')), set('abcefg'))
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('ccb')), set('a'))
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('ef')), set('abcef'))

    def test_xor(self):
        i = self.s.symmetric_difference(self.otherword)
        self.assertEqual(self.s ^ set(self.otherword), i)
        self.assertEqual(self.s ^ frozenset(self.otherword), i)
        try:
            self.s ^ self.otherword
        except TypeError:
            pass
        else:
            self.fail("s^t did not screen-out general iterables")

    def test_equality(self):
        self.assertEqual(self.s, set(self.word))
        self.assertEqual(self.s, frozenset(self.word))
        # Comparison with a non-set iterable is simply unequal.
        self.assertEqual(self.s == self.word, False)
        self.assertNotEqual(self.s, set(self.otherword))
        self.assertNotEqual(self.s, frozenset(self.otherword))
        self.assertEqual(self.s != self.word, True)

    def test_setOfFrozensets(self):
        t = map(frozenset, ['abcdef', 'bcd', 'bdcb', 'fed', 'fedccba'])
        s = self.thetype(t)
        self.assertEqual(len(s), 3)

    def test_sub_and_super(self):
        p, q, r = map(self.thetype, ['ab', 'abcde', 'def'])
        self.assertTrue(p < q)
        self.assertTrue(p <= q)
        self.assertTrue(q <= q)
        self.assertTrue(q > p)
        self.assertTrue(q >= p)
        self.assertFalse(q < r)
        self.assertFalse(q <= r)
        self.assertFalse(q > r)
        self.assertFalse(q >= r)
        self.assertTrue(set('a').issubset('abc'))
        self.assertTrue(set('abc').issuperset('a'))
        self.assertFalse(set('a').issubset('cbs'))
        self.assertFalse(set('cbs').issuperset('a'))

    def test_pickling(self):
        for i in range(pickle.HIGHEST_PROTOCOL + 1):
            p = pickle.dumps(self.s, i)
            dup = pickle.loads(p)
            self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup))
            # Subclass instances have a __dict__; its contents must survive.
            if type(self.s) not in (set, frozenset):
                self.s.x = 10
                p = pickle.dumps(self.s, i)
                dup = pickle.loads(p)
                self.assertEqual(self.s.x, dup.x)

    def test_iterator_pickling(self):
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            itorg = iter(self.s)
            data = self.thetype(self.s)
            d = pickle.dumps(itorg, proto)
            it = pickle.loads(d)
            # Set iterators unpickle as list iterators due to the
            # undefined order of set items.
            # self.assertEqual(type(itorg), type(it))
            self.assertIsInstance(it, collections.abc.Iterator)
            self.assertEqual(self.thetype(it), data)

            # A partially consumed iterator must resume after the item taken.
            it = pickle.loads(d)
            try:
                drop = next(it)
            except StopIteration:
                continue
            d = pickle.dumps(it, proto)
            it = pickle.loads(d)
            self.assertEqual(self.thetype(it), data - self.thetype((drop,)))

    def test_deepcopy(self):
        class Tracer:
            def __init__(self, value):
                self.value = value
            def __hash__(self):
                return self.value
            def __deepcopy__(self, memo=None):
                return Tracer(self.value + 1)
        t = Tracer(10)
        s = self.thetype([t])
        dup = copy.deepcopy(s)
        self.assertIsNot(s, dup)
        (newt,) = dup
        # The element itself was deep-copied via Tracer.__deepcopy__.
        self.assertIsNot(t, newt)
        self.assertEqual(t.value + 1, newt.value)

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check.
        # Nothing is asserted here; presumably this exists for refleak /
        # gc-traversal runs of the suite -- confirm.
        class A:
            pass
        s = set(A() for i in range(1000))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = set([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(self.thetype):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_badcmp(self):
        s = self.thetype([BadCmp()])
        # Detect comparison errors during insertion and lookup
        self.assertRaises(RuntimeError, self.thetype, [BadCmp(), BadCmp()])
        self.assertRaises(RuntimeError, s.__contains__, BadCmp())
        # Detect errors during mutating operations
        if hasattr(s, 'add'):
            self.assertRaises(RuntimeError, s.add, BadCmp())
            self.assertRaises(RuntimeError, s.discard, BadCmp())
            self.assertRaises(RuntimeError, s.remove, BadCmp())

    def test_cyclical_repr(self):
        w = ReprWrapper()
        s = self.thetype([w])
        w.value = s
        if self.thetype is set:
            self.assertEqual(repr(s), '{set(...)}')
        else:
            name = repr(s).partition('(')[0]    # strip class name
            self.assertEqual(repr(s), '%s({%s(...)})' % (name, name))

    def test_cyclical_print(self):
        w = ReprWrapper()
        s = self.thetype([w])
        w.value = s
        # `with` guarantees each handle is closed even if open() or the
        # assertion fails; the temp file is always removed.
        try:
            with open(support.TESTFN, "w") as fo:
                fo.write(str(s))
            with open(support.TESTFN, "r") as fo:
                self.assertEqual(fo.read(), repr(s))
        finally:
            support.unlink(support.TESTFN)

    def test_do_not_rehash_dict_keys(self):
        # Converting between dicts, sets and frozensets must not call the
        # elements' __hash__ again after the initial insertion into `d`.
        n = 10
        d = dict.fromkeys(map(HashCountingInt, range(n)))

        def hash_calls():
            return sum(elem.hash_count for elem in d)

        self.assertEqual(hash_calls(), n)
        s = self.thetype(d)
        self.assertEqual(hash_calls(), n)
        s.difference(d)
        self.assertEqual(hash_calls(), n)
        if hasattr(s, 'symmetric_difference_update'):
            s.symmetric_difference_update(d)
        self.assertEqual(hash_calls(), n)
        dict.fromkeys(set(d))
        self.assertEqual(hash_calls(), n)
        dict.fromkeys(frozenset(d))
        self.assertEqual(hash_calls(), n)
        d3 = dict.fromkeys(frozenset(d), 123)
        self.assertEqual(hash_calls(), n)
        self.assertEqual(d3, dict.fromkeys(d, 123))

    def test_container_iterator(self):
        # Bug #3680: tp_traverse was not implemented for set iterator object
        class C(object):
            pass
        obj = C()
        ref = weakref.ref(obj)
        container = set([obj, 1])
        obj.x = iter(container)
        del obj, container
        gc.collect()
        self.assertIsNone(ref(), "Cycle was not collected")

    def test_free_after_iterating(self):
        support.check_free_after_iterating(self, iter, self.thetype)
367
class TestSet(TestJointOps, unittest.TestCase):
    """TestJointOps run against set, plus tests for set's mutating API."""
    thetype = set
    basetype = set

    def test_init(self):
        s = self.thetype()
        s.__init__(self.word)
        self.assertEqual(s, set(self.word))
        # Calling __init__ again replaces the previous contents.
        s.__init__(self.otherword)
        self.assertEqual(s, set(self.otherword))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = self.thetype(range(3))
        t = self.thetype(s)
        self.assertIsNot(s, t)

    def test_set_literal(self):
        s = set([1,2,3])
        t = {1,2,3}
        self.assertEqual(s, t)

    def test_set_literal_insertion_order(self):
        # SF Issue #26020 -- Expect left to right insertion
        s = {1, 1.0, True}
        self.assertEqual(len(s), 1)
        stored_value = s.pop()
        self.assertEqual(type(stored_value), int)

    def test_set_literal_evaluation_order(self):
        # Expect left to right expression evaluation
        events = []
        def record(obj):
            events.append(obj)
        s = {record(1), record(2), record(3)}
        self.assertEqual(events, [1, 2, 3])

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, set())
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertIsNot(self.s, dup)
        self.assertEqual(type(dup), self.basetype)

    def test_add(self):
        self.s.add('Q')
        self.assertIn('Q', self.s)
        dup = self.s.copy()
        # Adding an existing element is a no-op.
        self.s.add('Q')
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])

    def test_remove(self):
        self.s.remove('a')
        self.assertNotIn('a', self.s)
        self.assertRaises(KeyError, self.s.remove, 'Q')
        self.assertRaises(TypeError, self.s.remove, [])
        s = self.thetype([frozenset(self.word)])
        self.assertIn(self.thetype(self.word), s)
        s.remove(self.thetype(self.word))
        self.assertNotIn(self.thetype(self.word), s)
        self.assertRaises(KeyError, self.s.remove, self.thetype(self.word))

    def test_remove_keyerror_unpacking(self):
        # bug:  www.python.org/sf/1576657
        # A tuple key must not be unpacked into KeyError's args.
        for v1 in ['Q', (1,)]:
            try:
                self.s.remove(v1)
            except KeyError as e:
                v2 = e.args[0]
                self.assertEqual(v1, v2)
            else:
                self.fail()

    def test_remove_keyerror_set(self):
        key = self.thetype([3, 4])
        try:
            self.s.remove(key)
        except KeyError as e:
            # The KeyError must carry the original key object.
            self.assertIs(e.args[0], key,
                          "KeyError should be {0}, not {1}".format(key,
                                                                   e.args[0]))
        else:
            self.fail()

    def test_discard(self):
        self.s.discard('a')
        self.assertNotIn('a', self.s)
        self.s.discard('Q')
        self.assertRaises(TypeError, self.s.discard, [])
        s = self.thetype([frozenset(self.word)])
        self.assertIn(self.thetype(self.word), s)
        s.discard(self.thetype(self.word))
        self.assertNotIn(self.thetype(self.word), s)
        s.discard(self.thetype(self.word))

    def test_pop(self):
        for i in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            self.assertIn(c, self.s)
        self.assertRaises(PassThru, self.s.update, check_pass_thru())
        self.assertRaises(TypeError, self.s.update, [[]])
        for p, q in (('cdc', 'abcd'), ('efgfe', 'abcefg'), ('ccb', 'abc'), ('ef', 'abcef')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.update(C(p)), None)
                self.assertEqual(s, set(q))
        q = 'ahi'
        for p in ('cdc', 'efgfe', 'ccb', 'ef', 'abcda'):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.update(C(p), C(q)), None)
                # Compare with the ORIGINAL contents: the previous
                # `set(s) | ...` made this assertion trivially true.
                self.assertEqual(s, set('abcba') | set(p) | set(q))

    def test_ior(self):
        self.s |= set(self.otherword)
        for c in (self.word + self.otherword):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if c in self.otherword and c in self.word:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.intersection_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.intersection_update, [[]])
        for p, q in (('cdc', 'c'), ('efgfe', ''), ('ccb', 'bc'), ('ef', '')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.intersection_update(C(p)), None)
                self.assertEqual(s, set(q))
                ss = 'abcba'
                s = self.thetype(ss)
                t = 'cbc'
                self.assertEqual(s.intersection_update(C(p), C(t)), None)
                self.assertEqual(s, set('abcba')&set(p)&set(t))

    def test_iand(self):
        self.s &= set(self.otherword)
        for c in (self.word + self.otherword):
            if c in self.otherword and c in self.word:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if c in self.word and c not in self.otherword:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.difference_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
        for p, q in (('cdc', 'ab'), ('efgfe', 'abc'), ('ccb', 'a'), ('ef', 'abc')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.difference_update(C(p)), None)
                self.assertEqual(s, set(q))

                s = self.thetype('abcdefghih')
                s.difference_update()
                self.assertEqual(s, self.thetype('abcdefghih'))

                s = self.thetype('abcdefghih')
                s.difference_update(C('aba'))
                self.assertEqual(s, self.thetype('cdefghih'))

                s = self.thetype('abcdefghih')
                s.difference_update(C('cdc'), C('aba'))
                self.assertEqual(s, self.thetype('efghih'))

    def test_isub(self):
        self.s -= set(self.otherword)
        for c in (self.word + self.otherword):
            if c in self.word and c not in self.otherword:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if (c in self.word) ^ (c in self.otherword):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.symmetric_difference_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
        for p, q in (('cdc', 'abd'), ('efgfe', 'abcefg'), ('ccb', 'a'), ('ef', 'abcef')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.symmetric_difference_update(C(p)), None)
                self.assertEqual(s, set(q))

    def test_ixor(self):
        self.s ^= set(self.otherword)
        for c in (self.word + self.otherword):
            if (c in self.word) ^ (c in self.otherword):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        # In-place operators must cope with the right operand being the
        # very object that is being mutated.
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, self.thetype())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, self.thetype())

    def test_weakref(self):
        s = self.thetype('gallahad')
        p = weakref.proxy(s)
        self.assertEqual(str(p), str(s))
        s = None
        self.assertRaises(ReferenceError, str, p)

    def test_rich_compare(self):
        class TestRichSetCompare:
            def __gt__(self, some_set):
                self.gt_called = True
                return False
            def __lt__(self, some_set):
                self.lt_called = True
                return False
            def __ge__(self, some_set):
                self.ge_called = True
                return False
            def __le__(self, some_set):
                self.le_called = True
                return False

        # This first tries the builtin rich set comparison, which doesn't know
        # how to handle the custom object. Upon returning NotImplemented, the
        # corresponding comparison on the right object is invoked.
        myset = {1, 2, 3}

        myobj = TestRichSetCompare()
        myset < myobj
        self.assertTrue(myobj.gt_called)

        myobj = TestRichSetCompare()
        myset > myobj
        self.assertTrue(myobj.lt_called)

        myobj = TestRichSetCompare()
        myset <= myobj
        self.assertTrue(myobj.ge_called)

        myobj = TestRichSetCompare()
        myset >= myobj
        self.assertTrue(myobj.le_called)

    @unittest.skipUnless(hasattr(set, "test_c_api"),
                         'C API test only available in a debug build')
    def test_c_api(self):
        self.assertEqual(set().test_c_api(), True)
651
class SetSubclass(set):
    """Trivial set subclass with no overrides; used as TestSetSubclass.thetype."""
654
class TestSetSubclass(TestSet):
    """Re-runs the whole TestSet suite against a trivial set subclass.

    Non-mutating operations such as union() and copy() are still expected
    to return a plain ``set`` (``basetype``), not the subclass.
    """
    basetype = set
    thetype = SetSubclass
658
class SetSubclassWithKeywordArgs(set):
    """set subclass whose __init__ accepts an extra (ignored) keyword argument.

    Used to check that set subclasses may take keyword arguments in
    their constructor.
    """

    def __init__(self, iterable=(), newarg=None):
        # An immutable empty tuple replaces the former mutable ``[]``
        # default; set.__init__ treats both as "no elements".
        set.__init__(self, iterable)
662
class TestSetSubclassWithKeywordArgs(TestSet):
    """Regression check for a set subclass whose __init__ takes a keyword.

    NOTE(review): ``thetype`` is not overridden, so the inherited TestSet
    tests run against plain ``set`` once more.  Only
    test_keywords_in_subclass exercises SetSubclassWithKeywordArgs.
    """

    def test_keywords_in_subclass(self):
        'SF bug #1486663 -- this used to erroneously raise a TypeError'
        # Constructing with only the extra keyword must not raise.
        SetSubclassWithKeywordArgs(newarg=1)
668
class TestFrozenSet(TestJointOps, unittest.TestCase):
    """TestJointOps run against frozenset, plus frozenset-only checks.

    The extra checks cover re-running __init__, object identity for empty
    and copied instances, and the quality of frozenset hashes.
    """
    basetype = frozenset
    thetype = frozenset

    def test_init(self):
        # Calling __init__ again on a frozenset must leave its contents alone.
        frozen = self.thetype(self.word)
        frozen.__init__(self.otherword)
        self.assertEqual(frozen, set(self.word))

    def test_singleton_empty_frozenset(self):
        shared = frozenset()
        empties = [
            frozenset(), frozenset([]), frozenset(()), frozenset(''),
            frozenset(), frozenset([]), frozenset(()), frozenset(''),
            frozenset(range(0)), frozenset(frozenset()),
            frozenset(shared), shared,
        ]
        # Every empty exact frozenset is expected to be one shared object.
        self.assertEqual(len({id(obj) for obj in empties}), 1)

    def test_constructor_identity(self):
        # Building an exact frozenset from another one reuses that object.
        first = self.thetype(range(3))
        second = self.thetype(first)
        self.assertEqual(id(first), id(second))

    def test_hash(self):
        self.assertEqual(hash(self.thetype('abcdeb')),
                         hash(self.thetype('ebecda')))

        # The hash must not depend on the order elements were supplied in.
        size = 100
        seq = [randrange(size) for _ in range(size)]
        observed = set()
        for _ in range(200):
            shuffle(seq)
            observed.add(hash(self.thetype(seq)))
        self.assertEqual(len(observed), 1)

    def test_copy(self):
        duplicate = self.s.copy()
        self.assertEqual(id(self.s), id(duplicate))

    def test_frozen_as_dictkey(self):
        seq = list(range(10)) + list('abcdefg') + ['apple']
        forward = self.thetype(seq)
        backward = self.thetype(reversed(seq))
        self.assertEqual(forward, backward)
        self.assertNotEqual(id(forward), id(backward))
        lookup = {forward: 42}
        self.assertEqual(lookup[backward], 42)

    def test_hash_caching(self):
        frozen = self.thetype('abcdcda')
        self.assertEqual(hash(frozen), hash(frozen))

    def test_hash_effectiveness(self):
        # All 2**13 subsets of {1, ..., 13} must hash to distinct values.
        bits = 13
        seen = set()
        record = seen.add
        elem_masks = [(i + 1, 1 << i) for i in range(bits)]
        for subset_bits in range(2 ** bits):
            record(hash(frozenset([e for e, m in elem_masks if m & subset_bits])))
        self.assertEqual(len(seen), 2 ** bits)

        def zf_range(count):
            # https://en.wikipedia.org/wiki/Set-theoretic_definition_of_natural_numbers
            nums = [frozenset()]
            for _ in range(count - 1):
                nums.append(frozenset(nums))
            return nums[:count]

        def powerset(items):
            for size in range(len(items) + 1):
                yield from map(frozenset, itertools.combinations(items, size))

        # Even when masked to the low n bits, the hashes of the 2**n subsets
        # must take more than a quarter of the 2**n possible values.
        for n in range(18):
            total = 2 ** n
            mask = total - 1
            for make_nums in (range, zf_range):
                hashes = map(hash, powerset(make_nums(n)))
                distinct = len({h & mask for h in hashes})
                self.assertGreater(4 * distinct, total)
750
class FrozenSetSubclass(frozenset):
    """Trivial frozenset subclass with no overrides; used as TestFrozenSetSubclass.thetype."""
753
class TestFrozenSetSubclass(TestFrozenSet):
    """Re-runs TestFrozenSet against a trivial frozenset subclass.

    Unlike exact frozensets, subclass instances are never shared.  The
    overrides below assert that construction, copy() and empty instances
    all produce distinct objects.
    """
    basetype = frozenset
    thetype = FrozenSetSubclass

    def test_constructor_identity(self):
        original = self.thetype(range(3))
        rebuilt = self.thetype(original)
        self.assertNotEqual(id(original), id(rebuilt))

    def test_copy(self):
        duplicate = self.s.copy()
        self.assertNotEqual(id(self.s), id(duplicate))

    def test_nested_empty_constructor(self):
        inner = self.thetype()
        self.assertEqual(inner, self.thetype(inner))

    def test_singleton_empty_frozenset(self):
        make = self.thetype
        plain_empty = frozenset()
        sub_empty = make()
        empties = [
            make(), make([]), make(()), make(''),
            make(), make([]), make(()), make(''),
            make(range(0)), make(make()),
            make(frozenset()), plain_empty, sub_empty,
            make(plain_empty), make(sub_empty),
        ]
        # Every empty subclass instance must be a separate object.
        self.assertEqual(len({id(obj) for obj in empties}), len(empties))
782
783# Tests taken from test_sets.py =============================================
784
# Module-level empty set used by the basic-operation tests as the neutral
# operand of union / intersection / isdisjoint checks.
empty_set = set(())
786
787#==============================================================================
788
class TestBasicOps:
    """Mixin of generic checks run against one prepared set.

    Concrete subclasses combine this with unittest.TestCase and define in
    ``setUp``:

    * ``self.set``    -- the set under test;
    * ``self.dup``    -- a separately built set equal to ``self.set``;
    * ``self.values`` -- the elements ``self.set`` was built from;
    * ``self.length`` -- the expected ``len(self.set)``;
    * ``self.repr``   -- the exact expected ``repr(self.set)``, or None to skip
      that check (subclasses may instead override ``test_repr``).
    """

    def test_repr(self):
        if self.repr is not None:
            self.assertEqual(repr(self.set), self.repr)

    def check_repr_against_values(self):
        # Sets are unordered, so the element order inside the repr is not
        # fixed; compare the sorted element reprs instead of the whole text.
        text = repr(self.set)
        self.assertTrue(text.startswith('{'))
        self.assertTrue(text.endswith('}'))

        result = text[1:-1].split(', ')
        result.sort()
        sorted_repr_values = [repr(value) for value in self.values]
        sorted_repr_values.sort()
        self.assertEqual(result, sorted_repr_values)

    def test_print(self):
        # str() written to a file must read back identical to repr().
        # ``with`` guarantees each handle is closed and, unlike closing a
        # shared name in ``finally``, cannot fail with UnboundLocalError (and
        # hide the real error) when open() itself raises.
        try:
            with open(support.TESTFN, "w") as fo:
                fo.write(str(self.set))
            with open(support.TESTFN, "r") as fo:
                self.assertEqual(fo.read(), repr(self.set))
        finally:
            support.unlink(support.TESTFN)

    def test_length(self):
        self.assertEqual(len(self.set), self.length)

    def test_self_equality(self):
        self.assertEqual(self.set, self.set)

    def test_equivalent_equality(self):
        self.assertEqual(self.set, self.dup)

    def test_copy(self):
        self.assertEqual(self.set.copy(), self.dup)

    def test_self_union(self):
        result = self.set | self.set
        self.assertEqual(result, self.dup)

    def test_empty_union(self):
        result = self.set | empty_set
        self.assertEqual(result, self.dup)

    def test_union_empty(self):
        result = empty_set | self.set
        self.assertEqual(result, self.dup)

    def test_self_intersection(self):
        result = self.set & self.set
        self.assertEqual(result, self.dup)

    def test_empty_intersection(self):
        result = self.set & empty_set
        self.assertEqual(result, empty_set)

    def test_intersection_empty(self):
        result = empty_set & self.set
        self.assertEqual(result, empty_set)

    def test_self_isdisjoint(self):
        # A set is disjoint from itself exactly when it is empty.
        result = self.set.isdisjoint(self.set)
        self.assertEqual(result, not self.set)

    def test_empty_isdisjoint(self):
        result = self.set.isdisjoint(empty_set)
        self.assertEqual(result, True)

    def test_isdisjoint_empty(self):
        result = empty_set.isdisjoint(self.set)
        self.assertEqual(result, True)

    def test_self_symmetric_difference(self):
        result = self.set ^ self.set
        self.assertEqual(result, empty_set)

    def test_empty_symmetric_difference(self):
        result = self.set ^ empty_set
        self.assertEqual(result, self.set)

    def test_self_difference(self):
        result = self.set - self.set
        self.assertEqual(result, empty_set)

    def test_empty_difference(self):
        result = self.set - empty_set
        self.assertEqual(result, self.dup)

    def test_empty_difference_rev(self):
        result = empty_set - self.set
        self.assertEqual(result, empty_set)

    def test_iteration(self):
        for v in self.set:
            self.assertIn(v, self.values)
        # A fresh iterator must report the full size as its length hint.
        setiter = iter(self.set)
        self.assertEqual(setiter.__length_hint__(), len(self.set))

    def test_pickling(self):
        # Round-trip through every available pickle protocol.  The result is
        # deliberately not named ``copy`` so the copy module stays visible.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            p = pickle.dumps(self.set, proto)
            unpickled = pickle.loads(p)
            self.assertEqual(self.set, unpickled,
                             "%s != %s" % (self.set, unpickled))
897
898#------------------------------------------------------------------------------
899
class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase):
    """TestBasicOps checks applied to a set with no elements."""

    def setUp(self):
        values = []
        self.values = values
        self.case = "empty set"
        self.repr = "set()"
        self.length = 0
        self.set = set(values)
        self.dup = set(values)
908
909#------------------------------------------------------------------------------
910
class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
    """TestBasicOps checks applied to a one-element set holding an int."""

    def setUp(self):
        values = [3]
        self.values = values
        self.case = "unit set (number)"
        self.repr = "{3}"
        self.length = 1
        self.set = set(values)
        self.dup = set(values)

    def test_in(self):
        # The sole member must be found.
        self.assertIn(3, self.set)

    def test_not_in(self):
        # A value that was never added must not be found.
        self.assertNotIn(2, self.set)
925
926#------------------------------------------------------------------------------
927
class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
    """TestBasicOps checks applied to a one-element set holding a tuple."""

    def setUp(self):
        values = [(0, "zero")]
        self.values = values
        self.case = "unit set (tuple)"
        self.repr = "{(0, 'zero')}"
        self.length = 1
        self.set = set(values)
        self.dup = set(values)

    def test_in(self):
        # An equal (not identical) tuple must be found.
        self.assertIn((0, "zero"), self.set)

    def test_not_in(self):
        self.assertNotIn(9, self.set)
942
943#------------------------------------------------------------------------------
944
class TestBasicOpsTriple(TestBasicOps, unittest.TestCase):
    """TestBasicOps checks on three heterogeneous elements (int, str, function).

    ``repr`` is None because the element order in the set's repr is not
    predictable, so the exact-repr check is skipped.
    """

    def setUp(self):
        values = [0, "zero", operator.add]
        self.values = values
        self.case = "triple set"
        self.repr = None
        self.length = 3
        self.set = set(values)
        self.dup = set(values)
953
954#------------------------------------------------------------------------------
955
class TestBasicOpsString(TestBasicOps, unittest.TestCase):
    """TestBasicOps checks applied to a set of three one-character strings."""

    def setUp(self):
        values = ["a", "b", "c"]
        self.values = values
        self.case = "string set"
        self.length = 3
        self.set = set(values)
        self.dup = set(values)

    def test_repr(self):
        # Element order is unpredictable; compare element by element instead.
        self.check_repr_against_values()
966
967#------------------------------------------------------------------------------
968
class TestBasicOpsBytes(TestBasicOps, unittest.TestCase):
    """TestBasicOps checks applied to a set of three one-byte bytes objects."""

    def setUp(self):
        values = [b"a", b"b", b"c"]
        self.values = values
        self.case = "bytes set"
        self.length = 3
        self.set = set(values)
        self.dup = set(values)

    def test_repr(self):
        # Element order is unpredictable; compare element by element instead.
        self.check_repr_against_values()
979
980#------------------------------------------------------------------------------
981
class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
    """TestBasicOps checks on a set mixing str and bytes elements.

    Comparing str with bytes can emit BytesWarning (e.g. under ``python -b``),
    so that warning category is ignored for the duration of each test.
    """

    def setUp(self):
        self._warning_filters = support.check_warnings()
        self._warning_filters.__enter__()
        # Register the restore step immediately: cleanups run even if the
        # remainder of setUp raises, whereas tearDown would then be skipped
        # and the modified warning filters would leak into later tests.
        self.addCleanup(self._warning_filters.__exit__, None, None, None)
        warnings.simplefilter('ignore', BytesWarning)
        self.case   = "string and bytes set"
        self.values = ["a", "b", b"a", b"b"]
        self.set    = set(self.values)
        self.dup    = set(self.values)
        self.length = 4

    def test_repr(self):
        # Element order is unpredictable; compare element by element instead.
        self.check_repr_against_values()
998
999#==============================================================================
1000
def baditer():
    """Generator whose very first step raises TypeError.

    The ``yield`` is never reached; it only makes this a generator function,
    so the error surfaces when iteration starts rather than at call time.
    """
    raise TypeError()
    yield True
1004
def gooditer():
    """Generator producing the single value True."""
    yield from (True,)
1007
1008class TestExceptionPropagation(unittest.TestCase):
1009    """SF 628246:  Set constructor should not trap iterator TypeErrors"""
1010
1011    def test_instanceWithException(self):
1012        self.assertRaises(TypeError, set, baditer())
1013
1014    def test_instancesWithoutException(self):
1015        # All of these iterables should load without exception.
1016        set([1,2,3])
1017        set((1,2,3))
1018        set({'one':1, 'two':2, 'three':3})
1019        set(range(3))
1020        set('abc')
1021        set(gooditer())
1022
1023    def test_changingSizeWhileIterating(self):
1024        s = set([1,2,3])
1025        try:
1026            for i in s:
1027                s.update([4])
1028        except RuntimeError:
1029            pass
1030        else:
1031            self.fail("no exception when changing size during iteration")
1032
1033#==============================================================================
1034
class TestSetOfSets(unittest.TestCase):
    """A set may contain frozensets, and they round-trip through its API."""

    def test_constructor(self):
        member = frozenset([1])
        container = set([member])
        # pop() must hand back the element with its original type.
        popped = container.pop()
        self.assertEqual(type(popped), frozenset)
        # Rebuild through add(), then take it out again with remove().
        container.add(member)
        container.remove(member)
        self.assertEqual(container, set())
        # discard() of an absent element must not raise KeyError.
        container.discard(member)
1045
1046#==============================================================================
1047
1048class TestBinaryOps(unittest.TestCase):
1049    def setUp(self):
1050        self.set = set((2, 4, 6))
1051
1052    def test_eq(self):              # SF bug 643115
1053        self.assertEqual(self.set, set({2:1,4:3,6:5}))
1054
1055    def test_union_subset(self):
1056        result = self.set | set([2])
1057        self.assertEqual(result, set((2, 4, 6)))
1058
1059    def test_union_superset(self):
1060        result = self.set | set([2, 4, 6, 8])
1061        self.assertEqual(result, set([2, 4, 6, 8]))
1062
1063    def test_union_overlap(self):
1064        result = self.set | set([3, 4, 5])
1065        self.assertEqual(result, set([2, 3, 4, 5, 6]))
1066
1067    def test_union_non_overlap(self):
1068        result = self.set | set([8])
1069        self.assertEqual(result, set([2, 4, 6, 8]))
1070
1071    def test_intersection_subset(self):
1072        result = self.set & set((2, 4))
1073        self.assertEqual(result, set((2, 4)))
1074
1075    def test_intersection_superset(self):
1076        result = self.set & set([2, 4, 6, 8])
1077        self.assertEqual(result, set([2, 4, 6]))
1078
1079    def test_intersection_overlap(self):
1080        result = self.set & set([3, 4, 5])
1081        self.assertEqual(result, set([4]))
1082
1083    def test_intersection_non_overlap(self):
1084        result = self.set & set([8])
1085        self.assertEqual(result, empty_set)
1086
1087    def test_isdisjoint_subset(self):
1088        result = self.set.isdisjoint(set((2, 4)))
1089        self.assertEqual(result, False)
1090
1091    def test_isdisjoint_superset(self):
1092        result = self.set.isdisjoint(set([2, 4, 6, 8]))
1093        self.assertEqual(result, False)
1094
1095    def test_isdisjoint_overlap(self):
1096        result = self.set.isdisjoint(set([3, 4, 5]))
1097        self.assertEqual(result, False)
1098
1099    def test_isdisjoint_non_overlap(self):
1100        result = self.set.isdisjoint(set([8]))
1101        self.assertEqual(result, True)
1102
1103    def test_sym_difference_subset(self):
1104        result = self.set ^ set((2, 4))
1105        self.assertEqual(result, set([6]))
1106
1107    def test_sym_difference_superset(self):
1108        result = self.set ^ set((2, 4, 6, 8))
1109        self.assertEqual(result, set([8]))
1110
1111    def test_sym_difference_overlap(self):
1112        result = self.set ^ set((3, 4, 5))
1113        self.assertEqual(result, set([2, 3, 5, 6]))
1114
1115    def test_sym_difference_non_overlap(self):
1116        result = self.set ^ set([8])
1117        self.assertEqual(result, set([2, 4, 6, 8]))
1118
1119#==============================================================================
1120
1121class TestUpdateOps(unittest.TestCase):
1122    def setUp(self):
1123        self.set = set((2, 4, 6))
1124
1125    def test_union_subset(self):
1126        self.set |= set([2])
1127        self.assertEqual(self.set, set((2, 4, 6)))
1128
1129    def test_union_superset(self):
1130        self.set |= set([2, 4, 6, 8])
1131        self.assertEqual(self.set, set([2, 4, 6, 8]))
1132
1133    def test_union_overlap(self):
1134        self.set |= set([3, 4, 5])
1135        self.assertEqual(self.set, set([2, 3, 4, 5, 6]))
1136
1137    def test_union_non_overlap(self):
1138        self.set |= set([8])
1139        self.assertEqual(self.set, set([2, 4, 6, 8]))
1140
1141    def test_union_method_call(self):
1142        self.set.update(set([3, 4, 5]))
1143        self.assertEqual(self.set, set([2, 3, 4, 5, 6]))
1144
1145    def test_intersection_subset(self):
1146        self.set &= set((2, 4))
1147        self.assertEqual(self.set, set((2, 4)))
1148
1149    def test_intersection_superset(self):
1150        self.set &= set([2, 4, 6, 8])
1151        self.assertEqual(self.set, set([2, 4, 6]))
1152
1153    def test_intersection_overlap(self):
1154        self.set &= set([3, 4, 5])
1155        self.assertEqual(self.set, set([4]))
1156
1157    def test_intersection_non_overlap(self):
1158        self.set &= set([8])
1159        self.assertEqual(self.set, empty_set)
1160
1161    def test_intersection_method_call(self):
1162        self.set.intersection_update(set([3, 4, 5]))
1163        self.assertEqual(self.set, set([4]))
1164
1165    def test_sym_difference_subset(self):
1166        self.set ^= set((2, 4))
1167        self.assertEqual(self.set, set([6]))
1168
1169    def test_sym_difference_superset(self):
1170        self.set ^= set((2, 4, 6, 8))
1171        self.assertEqual(self.set, set([8]))
1172
1173    def test_sym_difference_overlap(self):
1174        self.set ^= set((3, 4, 5))
1175        self.assertEqual(self.set, set([2, 3, 5, 6]))
1176
1177    def test_sym_difference_non_overlap(self):
1178        self.set ^= set([8])
1179        self.assertEqual(self.set, set([2, 4, 6, 8]))
1180
1181    def test_sym_difference_method_call(self):
1182        self.set.symmetric_difference_update(set([3, 4, 5]))
1183        self.assertEqual(self.set, set([2, 3, 5, 6]))
1184
1185    def test_difference_subset(self):
1186        self.set -= set((2, 4))
1187        self.assertEqual(self.set, set([6]))
1188
1189    def test_difference_superset(self):
1190        self.set -= set((2, 4, 6, 8))
1191        self.assertEqual(self.set, set([]))
1192
1193    def test_difference_overlap(self):
1194        self.set -= set((3, 4, 5))
1195        self.assertEqual(self.set, set([2, 6]))
1196
1197    def test_difference_non_overlap(self):
1198        self.set -= set([8])
1199        self.assertEqual(self.set, set([2, 4, 6]))
1200
1201    def test_difference_method_call(self):
1202        self.set.difference_update(set([3, 4, 5]))
1203        self.assertEqual(self.set, set([2, 6]))
1204
1205#==============================================================================
1206
class TestMutate(unittest.TestCase):
    """Mutating methods (add/remove/discard/clear/pop/update) on {'a','b','c'}."""

    def setUp(self):
        self.values = ["a", "b", "c"]
        self.set = set(self.values)

    def test_add_present(self):
        self.set.add("c")
        self.assertEqual(self.set, set("abc"))

    def test_add_absent(self):
        self.set.add("d")
        self.assertEqual(self.set, set("abcd"))

    def test_add_until_full(self):
        # Each distinct add() must grow the set by exactly one.
        tmp = set()
        for expected_len, v in enumerate(self.values, start=1):
            tmp.add(v)
            self.assertEqual(len(tmp), expected_len)
        self.assertEqual(tmp, self.set)

    def test_remove_present(self):
        self.set.remove("b")
        self.assertEqual(self.set, set("ac"))

    def test_remove_absent(self):
        # KeyError is a LookupError; accept any LookupError as the original did.
        with self.assertRaises(
                LookupError,
                msg="Removing missing element should have raised LookupError"):
            self.set.remove("d")

    def test_remove_until_empty(self):
        # Each remove() must shrink the set by exactly one.
        expected_len = len(self.set)
        for v in self.values:
            self.set.remove(v)
            expected_len -= 1
            self.assertEqual(len(self.set), expected_len)

    def test_discard_present(self):
        self.set.discard("c")
        self.assertEqual(self.set, set("ab"))

    def test_discard_absent(self):
        self.set.discard("d")
        self.assertEqual(self.set, set("abc"))

    def test_clear(self):
        self.set.clear()
        self.assertEqual(len(self.set), 0)

    def test_pop(self):
        # Popping until empty must yield every original element exactly once.
        popped = set()
        while self.set:
            popped.add(self.set.pop())
        self.assertEqual(len(popped), len(self.values))
        for v in self.values:
            self.assertIn(v, popped)

    def test_update_empty_tuple(self):
        self.set.update(())
        self.assertEqual(self.set, set(self.values))

    def test_update_unit_tuple_overlap(self):
        self.set.update(("a",))
        self.assertEqual(self.set, set(self.values))

    def test_update_unit_tuple_non_overlap(self):
        self.set.update(("a", "z"))
        self.assertEqual(self.set, set(self.values + ["z"]))
1278
1279#==============================================================================
1280
1281class TestSubsets:
1282
1283    case2method = {"<=": "issubset",
1284                   ">=": "issuperset",
1285                  }
1286
1287    reverse = {"==": "==",
1288               "!=": "!=",
1289               "<":  ">",
1290               ">":  "<",
1291               "<=": ">=",
1292               ">=": "<=",
1293              }
1294
1295    def test_issubset(self):
1296        x = self.left
1297        y = self.right
1298        for case in "!=", "==", "<", "<=", ">", ">=":
1299            expected = case in self.cases
1300            # Test the binary infix spelling.
1301            result = eval("x" + case + "y", locals())
1302            self.assertEqual(result, expected)
1303            # Test the "friendly" method-name spelling, if one exists.
1304            if case in TestSubsets.case2method:
1305                method = getattr(x, TestSubsets.case2method[case])
1306                result = method(y)
1307                self.assertEqual(result, expected)
1308
1309            # Now do the same for the operands reversed.
1310            rcase = TestSubsets.reverse[case]
1311            result = eval("y" + rcase + "x", locals())
1312            self.assertEqual(result, expected)
1313            if rcase in TestSubsets.case2method:
1314                method = getattr(y, TestSubsets.case2method[rcase])
1315                result = method(x)
1316                self.assertEqual(result, expected)
1317#------------------------------------------------------------------------------
1318
class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase):
    # Two empty sets: equal, hence each a (non-proper) subset of the other.
    name = "both empty"
    left = set()
    right = set()
    cases = ("==", "<=", ">=")
1324
1325#------------------------------------------------------------------------------
1326
class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase):
    # Two equal non-empty sets: non-proper subsets of each other.
    name = "equal pair"
    left = {1, 2}
    right = {1, 2}
    cases = ("==", "<=", ">=")
1332
1333#------------------------------------------------------------------------------
1334
class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase):
    # The empty set is a proper subset of any non-empty set.
    name = "one empty, one non-empty"
    left = set()
    right = {1, 2}
    cases = ("!=", "<", "<=")
1340
1341#------------------------------------------------------------------------------
1342
class TestSubsetPartial(TestSubsets, unittest.TestCase):
    # {1} is strictly contained in {1, 2}.
    name = "one a non-empty proper subset of other"
    left = {1}
    right = {1, 2}
    cases = ("!=", "<", "<=")
1348
1349#------------------------------------------------------------------------------
1350
class TestSubsetNonOverlap(TestSubsets, unittest.TestCase):
    # Disjoint non-empty sets: only "!=" holds.
    left  = set([1])
    right = set([2])
    name  = "neither empty, neither contains"
    # A one-element tuple (note the trailing comma), like the sibling
    # classes; a bare "!=" string would make ``case in cases`` a substring
    # test that only happens to give the right answers.
    cases = ("!=",)
1356
1357#==============================================================================
1358
1359class TestOnlySetsInBinaryOps:
1360
1361    def test_eq_ne(self):
1362        # Unlike the others, this is testing that == and != *are* allowed.
1363        self.assertEqual(self.other == self.set, False)
1364        self.assertEqual(self.set == self.other, False)
1365        self.assertEqual(self.other != self.set, True)
1366        self.assertEqual(self.set != self.other, True)
1367
1368    def test_ge_gt_le_lt(self):
1369        self.assertRaises(TypeError, lambda: self.set < self.other)
1370        self.assertRaises(TypeError, lambda: self.set <= self.other)
1371        self.assertRaises(TypeError, lambda: self.set > self.other)
1372        self.assertRaises(TypeError, lambda: self.set >= self.other)
1373
1374        self.assertRaises(TypeError, lambda: self.other < self.set)
1375        self.assertRaises(TypeError, lambda: self.other <= self.set)
1376        self.assertRaises(TypeError, lambda: self.other > self.set)
1377        self.assertRaises(TypeError, lambda: self.other >= self.set)
1378
1379    def test_update_operator(self):
1380        try:
1381            self.set |= self.other
1382        except TypeError:
1383            pass
1384        else:
1385            self.fail("expected TypeError")
1386
1387    def test_update(self):
1388        if self.otherIsIterable:
1389            self.set.update(self.other)
1390        else:
1391            self.assertRaises(TypeError, self.set.update, self.other)
1392
1393    def test_union(self):
1394        self.assertRaises(TypeError, lambda: self.set | self.other)
1395        self.assertRaises(TypeError, lambda: self.other | self.set)
1396        if self.otherIsIterable:
1397            self.set.union(self.other)
1398        else:
1399            self.assertRaises(TypeError, self.set.union, self.other)
1400
1401    def test_intersection_update_operator(self):
1402        try:
1403            self.set &= self.other
1404        except TypeError:
1405            pass
1406        else:
1407            self.fail("expected TypeError")
1408
1409    def test_intersection_update(self):
1410        if self.otherIsIterable:
1411            self.set.intersection_update(self.other)
1412        else:
1413            self.assertRaises(TypeError,
1414                              self.set.intersection_update,
1415                              self.other)
1416
1417    def test_intersection(self):
1418        self.assertRaises(TypeError, lambda: self.set & self.other)
1419        self.assertRaises(TypeError, lambda: self.other & self.set)
1420        if self.otherIsIterable:
1421            self.set.intersection(self.other)
1422        else:
1423            self.assertRaises(TypeError, self.set.intersection, self.other)
1424
1425    def test_sym_difference_update_operator(self):
1426        try:
1427            self.set ^= self.other
1428        except TypeError:
1429            pass
1430        else:
1431            self.fail("expected TypeError")
1432
1433    def test_sym_difference_update(self):
1434        if self.otherIsIterable:
1435            self.set.symmetric_difference_update(self.other)
1436        else:
1437            self.assertRaises(TypeError,
1438                              self.set.symmetric_difference_update,
1439                              self.other)
1440
1441    def test_sym_difference(self):
1442        self.assertRaises(TypeError, lambda: self.set ^ self.other)
1443        self.assertRaises(TypeError, lambda: self.other ^ self.set)
1444        if self.otherIsIterable:
1445            self.set.symmetric_difference(self.other)
1446        else:
1447            self.assertRaises(TypeError, self.set.symmetric_difference, self.other)
1448
1449    def test_difference_update_operator(self):
1450        try:
1451            self.set -= self.other
1452        except TypeError:
1453            pass
1454        else:
1455            self.fail("expected TypeError")
1456
1457    def test_difference_update(self):
1458        if self.otherIsIterable:
1459            self.set.difference_update(self.other)
1460        else:
1461            self.assertRaises(TypeError,
1462                              self.set.difference_update,
1463                              self.other)
1464
1465    def test_difference(self):
1466        self.assertRaises(TypeError, lambda: self.set - self.other)
1467        self.assertRaises(TypeError, lambda: self.other - self.set)
1468        if self.otherIsIterable:
1469            self.set.difference(self.other)
1470        else:
1471            self.assertRaises(TypeError, self.set.difference, self.other)
1472
1473#------------------------------------------------------------------------------
1474
class TestOnlySetsNumeric(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Non-set operand: a plain int, which is not iterable."""

    def setUp(self):
        self.otherIsIterable = False
        self.other = 19
        self.set = {1, 2, 3}
1480
1481#------------------------------------------------------------------------------
1482
class TestOnlySetsDict(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Non-set operand: a dict, iterable over its keys."""

    def setUp(self):
        self.otherIsIterable = True
        self.other = {1: 2, 3: 4}
        self.set = {1, 2, 3}
1488
1489#------------------------------------------------------------------------------
1490
class TestOnlySetsOperator(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Non-set operand: a builtin function, which is not iterable."""

    def setUp(self):
        self.otherIsIterable = False
        self.other = operator.add
        self.set = {1, 2, 3}
1496
1497#------------------------------------------------------------------------------
1498
class TestOnlySetsTuple(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Non-set operand: a tuple (iterable)."""

    def setUp(self):
        self.otherIsIterable = True
        self.other = (2, 4, 6)
        self.set = {1, 2, 3}
1504
1505#------------------------------------------------------------------------------
1506
class TestOnlySetsString(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Non-set operand: a str (iterable over its characters)."""

    def setUp(self):
        self.otherIsIterable = True
        self.other = 'abc'
        self.set = {1, 2, 3}
1512
1513#------------------------------------------------------------------------------
1514
class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Non-set operand: a one-shot generator over the even numbers 0..8."""

    def setUp(self):
        self.otherIsIterable = True
        self.other = (i for i in range(0, 10, 2))
        self.set = {1, 2, 3}
1523
1524#==============================================================================
1525
1526class TestCopying:
1527
1528    def test_copy(self):
1529        dup = self.set.copy()
1530        dup_list = sorted(dup, key=repr)
1531        set_list = sorted(self.set, key=repr)
1532        self.assertEqual(len(dup_list), len(set_list))
1533        for i in range(len(dup_list)):
1534            self.assertTrue(dup_list[i] is set_list[i])
1535
1536    def test_deep_copy(self):
1537        dup = copy.deepcopy(self.set)
1538        ##print type(dup), repr(dup)
1539        dup_list = sorted(dup, key=repr)
1540        set_list = sorted(self.set, key=repr)
1541        self.assertEqual(len(dup_list), len(set_list))
1542        for i in range(len(dup_list)):
1543            self.assertEqual(dup_list[i], set_list[i])
1544
1545#------------------------------------------------------------------------------
1546
class TestCopyingEmpty(TestCopying, unittest.TestCase):
    """Copying a set with no elements."""

    def setUp(self):
        self.set = set(())
1550
1551#------------------------------------------------------------------------------
1552
class TestCopyingSingleton(TestCopying, unittest.TestCase):
    """Copying a one-element set of strings."""

    def setUp(self):
        self.set = {"hello"}
1556
1557#------------------------------------------------------------------------------
1558
class TestCopyingTriple(TestCopying, unittest.TestCase):
    """Copying a set of three heterogeneous elements (str, int, None)."""

    def setUp(self):
        self.set = {"zero", 0, None}
1562
1563#------------------------------------------------------------------------------
1564
class TestCopyingTuple(TestCopying, unittest.TestCase):
    """Copying a set whose single element is a flat tuple."""

    def setUp(self):
        self.set = {(1, 2)}
1568
1569#------------------------------------------------------------------------------
1570
class TestCopyingNested(TestCopying, unittest.TestCase):
    """Copying a set whose single element is a tuple of tuples."""

    def setUp(self):
        self.set = {((1, 2), (3, 4))}
1574
1575#==============================================================================
1576
class TestIdentities(unittest.TestCase):
    """Algebraic identities linking the set operators to each other."""

    def setUp(self):
        self.a = set('abracadabra')
        self.b = set('alacazam')

    def test_binopsVsSubsets(self):
        left, right = self.a, self.b
        common = left & right
        union = left | right
        self.assertTrue(left - right < left)
        self.assertTrue(right - left < right)
        self.assertTrue(common < left)
        self.assertTrue(common < right)
        self.assertTrue(union > left)
        self.assertTrue(union > right)
        self.assertTrue(left ^ right < union)

    def test_commutativity(self):
        left, right = self.a, self.b
        for op in (operator.and_, operator.or_, operator.xor):
            self.assertEqual(op(left, right), op(right, left))
        # Difference is not commutative unless both sides are equal.
        if left != right:
            self.assertNotEqual(left - right, right - left)

    def test_summations(self):
        # Check that sums of parts equal the whole.
        left, right = self.a, self.b
        only_left = left - right
        only_right = right - left
        common = left & right
        union = left | right
        self.assertEqual(only_left | common | only_right, union)
        self.assertEqual(common | (left ^ right), union)
        self.assertEqual(left | only_right, union)
        self.assertEqual(only_left | right, union)
        self.assertEqual(only_left | common, left)
        self.assertEqual(only_right | common, right)
        self.assertEqual(only_left | only_right, left ^ right)

    def test_exclusion(self):
        # Check that inverse operations show non-overlap.
        left, right, zero = self.a, self.b, set()
        self.assertEqual((left - right) & right, zero)
        self.assertEqual((right - left) & left, zero)
        self.assertEqual((left & right) & (left ^ right), zero)
1617
1618# Tests derived from test_itertools.py =======================================
1619
def R(seqn):
    """Regular generator: re-yield every item of *seqn*, in order."""
    yield from seqn
1624
class G:
    """Sequence exposing only ``__getitem__`` (the legacy iteration protocol)."""

    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, i):
        # Plain delegation; an IndexError from seqn ends iteration.
        backing = self.seqn
        return backing[i]
1631
class I:
    """Sequence wrapper implementing the iterator protocol by hand."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        pos = self.i
        if pos < len(self.seqn):
            # Fetch before advancing so a failing lookup leaves the cursor put.
            item = self.seqn[pos]
            self.i = pos + 1
            return item
        raise StopIteration
1644
class Ig:
    """Sequence wrapper whose ``__iter__`` is a generator function."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        yield from self.seqn
1653
class X:
    """Has ``__next__`` but neither ``__iter__`` nor ``__getitem__``.

    It therefore is not iterable: ``iter()`` on it raises TypeError.
    """

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __next__(self):
        pos = self.i
        if pos < len(self.seqn):
            item = self.seqn[pos]
            self.i = pos + 1
            return item
        raise StopIteration
1664
class N:
    """Claims to be an iterator (``__iter__`` returns self) but has no ``__next__``."""

    def __init__(self, seqn):
        self.i = 0
        self.seqn = seqn

    def __iter__(self):
        return self
1672
class E:
    """Iterator whose every ``__next__`` call fails with ZeroDivisionError.

    Used to check that errors raised during iteration propagate unchanged.
    """

    def __init__(self, seqn):
        self.i = 0
        self.seqn = seqn

    def __iter__(self):
        return self

    def __next__(self):
        return 3 // 0
1682
class S:
    """Iterator that is exhausted from the start; its argument is ignored."""

    def __init__(self, seqn):
        pass

    def __iter__(self):
        return self

    def __next__(self):
        raise StopIteration
1691
1692from itertools import chain
def L(seqn):
    """Iterate *seqn* through several stacked iterator layers.

    The layers are: G (``__getitem__``), Ig (generator ``__iter__``), R
    (plain generator), an identity ``map`` and a single-argument ``chain``.
    """
    innermost = R(Ig(G(seqn)))
    passthrough = map(lambda item: item, innermost)
    return chain(passthrough)
1696
class TestVariousIteratorArgs(unittest.TestCase):
    # Feed set operations every iterable flavour defined above.  The
    # well-behaved wrappers (G, I, Ig, S, L, R) must produce the same result
    # as the equivalent plain input.  The broken ones must propagate their
    # errors: X is not iterable (TypeError), N's __iter__ returns a
    # non-iterator (TypeError), and E raises ZeroDivisionError from __next__.

    def test_constructor(self):
        for cons in (set, frozenset):
            for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
                for g in (G, I, Ig, S, L, R):
                    with self.subTest(cons=cons.__name__, data=s, iterator=g.__name__):
                        self.assertEqual(sorted(cons(g(s)), key=repr), sorted(g(s), key=repr))
                self.assertRaises(TypeError, cons, X(s))
                self.assertRaises(TypeError, cons, N(s))
                self.assertRaises(ZeroDivisionError, cons, E(s))

    def test_inline_methods(self):
        s = set('november')
        for data in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5), 'december'):
            for meth in (s.union, s.intersection, s.difference, s.symmetric_difference, s.isdisjoint):
                # S is omitted: it ignores its data, so meth(S(data)) would
                # not match meth(data).
                for g in (G, I, Ig, L, R):
                    with self.subTest(data=data, meth=meth.__name__, iterator=g.__name__):
                        expected = meth(data)
                        actual = meth(g(data))
                        if isinstance(expected, bool):
                            self.assertEqual(actual, expected)
                        else:
                            self.assertEqual(sorted(actual, key=repr), sorted(expected, key=repr))
                # Broken iterables wrap the same data under test, consistent
                # with test_constructor and test_inplace_methods.
                self.assertRaises(TypeError, meth, X(data))
                self.assertRaises(TypeError, meth, N(data))
                self.assertRaises(ZeroDivisionError, meth, E(data))

    def test_inplace_methods(self):
        for data in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5), 'december'):
            for methname in ('update', 'intersection_update',
                             'difference_update', 'symmetric_difference_update'):
                for g in (G, I, Ig, S, L, R):
                    with self.subTest(data=data, meth=methname, iterator=g.__name__):
                        # Applying the method to a materialized list and to
                        # the wrapped iterator must leave equal sets.
                        s = set('january')
                        t = s.copy()
                        getattr(s, methname)(list(g(data)))
                        getattr(t, methname)(g(data))
                        self.assertEqual(sorted(s, key=repr), sorted(t, key=repr))

                self.assertRaises(TypeError, getattr(set('january'), methname), X(data))
                self.assertRaises(TypeError, getattr(set('january'), methname), N(data))
                self.assertRaises(ZeroDivisionError, getattr(set('january'), methname), E(data))
1737
class bad_eq:
    """Hash-colliding key whose ``__eq__`` can sabotage a set merge.

    Every instance hashes to 0, so comparisons between instances go through
    ``__eq__``.  While the module-level flag ``be_bad`` is false, equality is
    plain identity.  Once it is true, ``__eq__`` empties the module-level set
    ``set2`` and raises ZeroDivisionError.
    """

    def __hash__(self):
        return 0

    def __eq__(self, other):
        if not be_bad:
            return self is other
        set2.clear()
        raise ZeroDivisionError
1746
class bad_dict_clear:
    """Hash-colliding key whose ``__eq__`` can empty the global ``dict2``.

    Every instance hashes to 0.  Whenever the module-level flag ``be_bad`` is
    true, ``__eq__`` clears ``dict2`` before answering.  The answer itself is
    always object identity.
    """

    def __hash__(self):
        return 0

    def __eq__(self, other):
        if be_bad:
            dict2.clear()  # side effect of merely comparing
        return other is self
1754
class TestWeirdBugs(unittest.TestCase):
    # Regression tests for crashes triggered when user-defined __hash__ /
    # __eq__ methods mutate a container while a set operation is using it.
    # Several statements below look redundant but are what provoke the
    # original bugs; do not "simplify" them.

    def test_8420_set_merge(self):
        # This used to segfault
        # Every bad_eq hashes to 0, so merging set2 into set1 has to call
        # bad_eq.__eq__.  With be_bad set, that call clears set2 mid-merge and
        # raises; update() must report the ZeroDivisionError.
        global be_bad, set2, dict2
        be_bad = False
        set1 = {bad_eq()}
        set2 = {bad_eq() for i in range(75)}
        be_bad = True
        self.assertRaises(ZeroDivisionError, set1.update, set2)

        # Same idea with a dict argument: bad_dict_clear.__eq__ empties dict2
        # during symmetric_difference_update().  There is no assertion; the
        # test only checks that this completes without crashing.
        be_bad = False
        set1 = {bad_dict_clear()}
        dict2 = {bad_dict_clear(): None}
        be_bad = True
        set1.symmetric_difference_update(dict2)

    def test_iter_and_mutate(self):
        # Issue #24581
        # Create an iterator, then clear() and refill the set so that its
        # size is again 100, as when the iterator was created.  Draining the
        # iterator afterwards must not crash; its output is not checked.
        s = set(range(100))
        s.clear()
        s.update(range(100))
        si = iter(s)
        s.clear()
        # NOTE(review): `a` is never read; presumably this allocation is
        # meant to reuse memory released by clear().  Keep it.
        a = list(range(100))
        s.update(range(100))
        list(si)

    def test_merge_and_mutate(self):
        # Every X hashes like 0 and its __eq__ clears the enclosing
        # variable `other`, so any comparison involving an X mutates `other`.
        class X:
            def __hash__(self):
                return hash(0)
            def __eq__(self, o):
                other.clear()
                return False

        # `other` must already be bound while the comprehension below runs:
        # its colliding X() instances are compared via X.__eq__, which
        # references `other`.
        other = set()
        other = {X() for i in range(10)}
        # The X entries collide with the existing 0, so update() ends up
        # calling X.__eq__, clearing `other` while it is being iterated.
        # This must not crash (no assertion).
        s = {0}
        s.update(other)
1794
# Application tests (based on David Eppstein's graph recipes) ===================================
1796
def powerset(U):
    """Generates all subsets of a set or sequence U.

    Each subset is yielded as a frozenset.  The first element of U is split
    off, and the subsets of the remaining elements are generated
    recursively.  Each of those is yielded twice: first without, then with
    the split-off element.  An empty U yields exactly one subset,
    ``frozenset()``.
    """
    rest = iter(U)
    try:
        first = next(rest)
    except StopIteration:
        yield frozenset()
        return
    head = frozenset([first])
    for subset in powerset(rest):
        yield subset
        yield subset | head
1807
def cube(n):
    """Graph of n-dimensional hypercube.

    The vertices are the subsets of ``range(n)``, as frozensets produced by
    ``powerset``.  Each vertex maps to the frozenset of its n neighbours:
    the subsets obtained by toggling exactly one element (symmetric
    difference with a singleton).
    """
    singletons = [frozenset([x]) for x in range(n)]
    return {vertex: frozenset(vertex ^ s for s in singletons)
            for vertex in powerset(range(n))}
1813
def linegraph(G):
    """Return the line graph of *G*.

    *G* maps each vertex to its neighbours.  Every edge {x, y} of G becomes a
    vertex of the result, keyed by ``frozenset([x, y])``.  Its value is the
    frozenset of edges that share an endpoint with it: edges at x other than
    the one to y, plus edges at y other than the one to x.
    """
    def incident(a, b):
        # Edges {a, z} touching a, excluding the edge back to b.
        return [frozenset([a, z]) for z in G[a] if z != b]

    line = {}
    for x in G:
        for y in G[x]:
            line[frozenset([x, y])] = frozenset(incident(x, y) + incident(y, x))
    return line
1825
def faces(G):
    """Return a set of faces in G, each face being the frozenset of its vertices.

    *G* maps each vertex to its neighbours.  Only triangles, squares and
    pentagons are searched for.  A walk v1-v2-v3 is closed into a triangle
    when v1 is a neighbour of v3.  Otherwise it is extended to v4, giving a
    square if v1 is a neighbour of v4.  Failing that, it is extended to v5,
    giving a pentagon if v1 is a neighbour of v5.
    """
    found = set()
    for v1, edges in G.items():
        for v2 in edges:
            for v3 in G[v2]:
                if v1 == v3:
                    continue  # stepped straight back to the start
                if v1 in G[v3]:
                    found.add(frozenset([v1, v2, v3]))
                    continue  # triangle closed; longer cycles not explored
                for v4 in G[v3]:
                    if v4 == v2:
                        continue
                    if v1 in G[v4]:
                        found.add(frozenset([v1, v2, v3, v4]))
                        continue  # square closed; pentagons not explored
                    for v5 in G[v4]:
                        if v5 == v3 or v5 == v2:
                            continue
                        if v1 in G[v5]:
                            found.add(frozenset([v1, v2, v3, v4, v5]))
    return found
1850
1851
class TestGraphs(unittest.TestCase):
    # Sanity checks for the graph helpers above (cube, linegraph, faces),
    # using two polyhedra with well-known vertex and face counts.

    def test_cube(self):

        g = cube(3)                             # vertex --> frozenset of its 3 neighbours
        vertices1 = set(g)
        self.assertEqual(len(vertices1), 8)     # eight vertices
        for neighbours in g.values():
            self.assertEqual(len(neighbours), 3)    # each vertex has three neighbours (three edges)
        vertices2 = {v for neighbours in g.values() for v in neighbours}
        self.assertEqual(vertices1, vertices2)  # every neighbour is itself a vertex of g

        cubefaces = faces(g)
        self.assertEqual(len(cubefaces), 6)     # six faces
        for face in cubefaces:
            self.assertEqual(len(face), 4)      # each face is a square

    def test_cuboctahedron(self):

        # https://en.wikipedia.org/wiki/Cuboctahedron
        # 8 triangular faces and 6 square faces
        # 12 equivalent vertices, each shared by two triangles and two squares

        g = cube(3)
        cuboctahedron = linegraph(g)            # cube edge --> frozenset of its 4 adjacent cube edges
        self.assertEqual(len(cuboctahedron), 12)    # twelve vertices

        vertices = set(cuboctahedron)
        for neighbours in cuboctahedron.values():
            self.assertEqual(len(neighbours), 4)    # each vertex connects to four other vertices
        othervertices = {v for neighbours in cuboctahedron.values() for v in neighbours}
        self.assertEqual(vertices, othervertices)   # every neighbour is itself a vertex

        cubofaces = faces(cuboctahedron)
        # faces() may also report 5-vertex cycles for this graph, so only the
        # triangle and square counts are pinned.
        facesizes = collections.Counter(len(face) for face in cubofaces)
        self.assertEqual(facesizes[3], 8)       # eight triangular faces
        self.assertEqual(facesizes[4], 6)       # six square faces

        for edge in cuboctahedron:              # cuboctahedron vertices are edges of the cube
            self.assertEqual(len(edge), 2)      # two cube vertices define an edge
            for cubevert in edge:
                self.assertIn(cubevert, g)
1898
1899#==============================================================================
1900
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
1903