• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import unittest
2from test import support
3import gc
4import weakref
5import operator
6import copy
7import pickle
8from random import randrange, shuffle
9import warnings
10import collections
11import collections.abc
12import itertools
13
class PassThru(Exception):
    """Marker exception raised by check_pass_thru()."""
16
def check_pass_thru():
    """Generator that raises PassThru the first time it is advanced."""
    raise PassThru()
    # Never reached; the presence of ``yield`` is what makes this a generator.
    yield 1
20
class BadCmp:
    """Object whose instances all hash alike and whose equality test raises."""

    def __eq__(self, other):
        # Any comparison against a BadCmp is an error
        raise RuntimeError

    def __hash__(self):
        # Constant hash: every instance lands in the same hash bucket
        return 1
26
class ReprWrapper:
    'Used to test self-referential repr() calls'

    def __repr__(self):
        # ``value`` is assigned by the test after construction
        wrapped = self.value
        return repr(wrapped)
31
class HashCountingInt(int):
    'int-like object that counts the number of times __hash__ is called'

    def __init__(self, *args):
        # Constructor arguments are accepted and ignored here; only the
        # counter is initialised.
        self.hash_count = 0

    def __hash__(self):
        # Record the call, then defer to the ordinary int hash
        self.hash_count = self.hash_count + 1
        return super().__hash__()
39
40class TestJointOps:
41    # Tests common to both set and frozenset
42
43    def setUp(self):
44        self.word = word = 'simsalabim'
45        self.otherword = 'madagascar'
46        self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
47        self.s = self.thetype(word)
48        self.d = dict.fromkeys(word)
49
50    def test_new_or_init(self):
51        self.assertRaises(TypeError, self.thetype, [], 2)
52        self.assertRaises(TypeError, set().__init__, a=1)
53
54    def test_uniquification(self):
55        actual = sorted(self.s)
56        expected = sorted(self.d)
57        self.assertEqual(actual, expected)
58        self.assertRaises(PassThru, self.thetype, check_pass_thru())
59        self.assertRaises(TypeError, self.thetype, [[]])
60
61    def test_len(self):
62        self.assertEqual(len(self.s), len(self.d))
63
64    def test_contains(self):
65        for c in self.letters:
66            self.assertEqual(c in self.s, c in self.d)
67        self.assertRaises(TypeError, self.s.__contains__, [[]])
68        s = self.thetype([frozenset(self.letters)])
69        self.assertIn(self.thetype(self.letters), s)
70
71    def test_union(self):
72        u = self.s.union(self.otherword)
73        for c in self.letters:
74            self.assertEqual(c in u, c in self.d or c in self.otherword)
75        self.assertEqual(self.s, self.thetype(self.word))
76        self.assertEqual(type(u), self.basetype)
77        self.assertRaises(PassThru, self.s.union, check_pass_thru())
78        self.assertRaises(TypeError, self.s.union, [[]])
79        for C in set, frozenset, dict.fromkeys, str, list, tuple:
80            self.assertEqual(self.thetype('abcba').union(C('cdc')), set('abcd'))
81            self.assertEqual(self.thetype('abcba').union(C('efgfe')), set('abcefg'))
82            self.assertEqual(self.thetype('abcba').union(C('ccb')), set('abc'))
83            self.assertEqual(self.thetype('abcba').union(C('ef')), set('abcef'))
84            self.assertEqual(self.thetype('abcba').union(C('ef'), C('fg')), set('abcefg'))
85
86        # Issue #6573
87        x = self.thetype()
88        self.assertEqual(x.union(set([1]), x, set([2])), self.thetype([1, 2]))
89
90    def test_or(self):
91        i = self.s.union(self.otherword)
92        self.assertEqual(self.s | set(self.otherword), i)
93        self.assertEqual(self.s | frozenset(self.otherword), i)
94        try:
95            self.s | self.otherword
96        except TypeError:
97            pass
98        else:
99            self.fail("s|t did not screen-out general iterables")
100
101    def test_intersection(self):
102        i = self.s.intersection(self.otherword)
103        for c in self.letters:
104            self.assertEqual(c in i, c in self.d and c in self.otherword)
105        self.assertEqual(self.s, self.thetype(self.word))
106        self.assertEqual(type(i), self.basetype)
107        self.assertRaises(PassThru, self.s.intersection, check_pass_thru())
108        for C in set, frozenset, dict.fromkeys, str, list, tuple:
109            self.assertEqual(self.thetype('abcba').intersection(C('cdc')), set('cc'))
110            self.assertEqual(self.thetype('abcba').intersection(C('efgfe')), set(''))
111            self.assertEqual(self.thetype('abcba').intersection(C('ccb')), set('bc'))
112            self.assertEqual(self.thetype('abcba').intersection(C('ef')), set(''))
113            self.assertEqual(self.thetype('abcba').intersection(C('cbcf'), C('bag')), set('b'))
114        s = self.thetype('abcba')
115        z = s.intersection()
116        if self.thetype == frozenset():
117            self.assertEqual(id(s), id(z))
118        else:
119            self.assertNotEqual(id(s), id(z))
120
121    def test_isdisjoint(self):
122        def f(s1, s2):
123            'Pure python equivalent of isdisjoint()'
124            return not set(s1).intersection(s2)
125        for larg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
126            s1 = self.thetype(larg)
127            for rarg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
128                for C in set, frozenset, dict.fromkeys, str, list, tuple:
129                    s2 = C(rarg)
130                    actual = s1.isdisjoint(s2)
131                    expected = f(s1, s2)
132                    self.assertEqual(actual, expected)
133                    self.assertTrue(actual is True or actual is False)
134
135    def test_and(self):
136        i = self.s.intersection(self.otherword)
137        self.assertEqual(self.s & set(self.otherword), i)
138        self.assertEqual(self.s & frozenset(self.otherword), i)
139        try:
140            self.s & self.otherword
141        except TypeError:
142            pass
143        else:
144            self.fail("s&t did not screen-out general iterables")
145
146    def test_difference(self):
147        i = self.s.difference(self.otherword)
148        for c in self.letters:
149            self.assertEqual(c in i, c in self.d and c not in self.otherword)
150        self.assertEqual(self.s, self.thetype(self.word))
151        self.assertEqual(type(i), self.basetype)
152        self.assertRaises(PassThru, self.s.difference, check_pass_thru())
153        self.assertRaises(TypeError, self.s.difference, [[]])
154        for C in set, frozenset, dict.fromkeys, str, list, tuple:
155            self.assertEqual(self.thetype('abcba').difference(C('cdc')), set('ab'))
156            self.assertEqual(self.thetype('abcba').difference(C('efgfe')), set('abc'))
157            self.assertEqual(self.thetype('abcba').difference(C('ccb')), set('a'))
158            self.assertEqual(self.thetype('abcba').difference(C('ef')), set('abc'))
159            self.assertEqual(self.thetype('abcba').difference(), set('abc'))
160            self.assertEqual(self.thetype('abcba').difference(C('a'), C('b')), set('c'))
161
162    def test_sub(self):
163        i = self.s.difference(self.otherword)
164        self.assertEqual(self.s - set(self.otherword), i)
165        self.assertEqual(self.s - frozenset(self.otherword), i)
166        try:
167            self.s - self.otherword
168        except TypeError:
169            pass
170        else:
171            self.fail("s-t did not screen-out general iterables")
172
173    def test_symmetric_difference(self):
174        i = self.s.symmetric_difference(self.otherword)
175        for c in self.letters:
176            self.assertEqual(c in i, (c in self.d) ^ (c in self.otherword))
177        self.assertEqual(self.s, self.thetype(self.word))
178        self.assertEqual(type(i), self.basetype)
179        self.assertRaises(PassThru, self.s.symmetric_difference, check_pass_thru())
180        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
181        for C in set, frozenset, dict.fromkeys, str, list, tuple:
182            self.assertEqual(self.thetype('abcba').symmetric_difference(C('cdc')), set('abd'))
183            self.assertEqual(self.thetype('abcba').symmetric_difference(C('efgfe')), set('abcefg'))
184            self.assertEqual(self.thetype('abcba').symmetric_difference(C('ccb')), set('a'))
185            self.assertEqual(self.thetype('abcba').symmetric_difference(C('ef')), set('abcef'))
186
187    def test_xor(self):
188        i = self.s.symmetric_difference(self.otherword)
189        self.assertEqual(self.s ^ set(self.otherword), i)
190        self.assertEqual(self.s ^ frozenset(self.otherword), i)
191        try:
192            self.s ^ self.otherword
193        except TypeError:
194            pass
195        else:
196            self.fail("s^t did not screen-out general iterables")
197
198    def test_equality(self):
199        self.assertEqual(self.s, set(self.word))
200        self.assertEqual(self.s, frozenset(self.word))
201        self.assertEqual(self.s == self.word, False)
202        self.assertNotEqual(self.s, set(self.otherword))
203        self.assertNotEqual(self.s, frozenset(self.otherword))
204        self.assertEqual(self.s != self.word, True)
205
206    def test_setOfFrozensets(self):
207        t = map(frozenset, ['abcdef', 'bcd', 'bdcb', 'fed', 'fedccba'])
208        s = self.thetype(t)
209        self.assertEqual(len(s), 3)
210
211    def test_sub_and_super(self):
212        p, q, r = map(self.thetype, ['ab', 'abcde', 'def'])
213        self.assertTrue(p < q)
214        self.assertTrue(p <= q)
215        self.assertTrue(q <= q)
216        self.assertTrue(q > p)
217        self.assertTrue(q >= p)
218        self.assertFalse(q < r)
219        self.assertFalse(q <= r)
220        self.assertFalse(q > r)
221        self.assertFalse(q >= r)
222        self.assertTrue(set('a').issubset('abc'))
223        self.assertTrue(set('abc').issuperset('a'))
224        self.assertFalse(set('a').issubset('cbs'))
225        self.assertFalse(set('cbs').issuperset('a'))
226
227    def test_pickling(self):
228        for i in range(pickle.HIGHEST_PROTOCOL + 1):
229            p = pickle.dumps(self.s, i)
230            dup = pickle.loads(p)
231            self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup))
232            if type(self.s) not in (set, frozenset):
233                self.s.x = 10
234                p = pickle.dumps(self.s, i)
235                dup = pickle.loads(p)
236                self.assertEqual(self.s.x, dup.x)
237
238    def test_iterator_pickling(self):
239        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
240            itorg = iter(self.s)
241            data = self.thetype(self.s)
242            d = pickle.dumps(itorg, proto)
243            it = pickle.loads(d)
244            # Set iterators unpickle as list iterators due to the
245            # undefined order of set items.
246            # self.assertEqual(type(itorg), type(it))
247            self.assertIsInstance(it, collections.abc.Iterator)
248            self.assertEqual(self.thetype(it), data)
249
250            it = pickle.loads(d)
251            try:
252                drop = next(it)
253            except StopIteration:
254                continue
255            d = pickle.dumps(it, proto)
256            it = pickle.loads(d)
257            self.assertEqual(self.thetype(it), data - self.thetype((drop,)))
258
259    def test_deepcopy(self):
260        class Tracer:
261            def __init__(self, value):
262                self.value = value
263            def __hash__(self):
264                return self.value
265            def __deepcopy__(self, memo=None):
266                return Tracer(self.value + 1)
267        t = Tracer(10)
268        s = self.thetype([t])
269        dup = copy.deepcopy(s)
270        self.assertNotEqual(id(s), id(dup))
271        for elem in dup:
272            newt = elem
273        self.assertNotEqual(id(t), id(newt))
274        self.assertEqual(t.value + 1, newt.value)
275
276    def test_gc(self):
277        # Create a nest of cycles to exercise overall ref count check
278        class A:
279            pass
280        s = set(A() for i in range(1000))
281        for elem in s:
282            elem.cycle = s
283            elem.sub = elem
284            elem.set = set([elem])
285
286    def test_subclass_with_custom_hash(self):
287        # Bug #1257731
288        class H(self.thetype):
289            def __hash__(self):
290                return int(id(self) & 0x7fffffff)
291        s=H()
292        f=set()
293        f.add(s)
294        self.assertIn(s, f)
295        f.remove(s)
296        f.add(s)
297        f.discard(s)
298
299    def test_badcmp(self):
300        s = self.thetype([BadCmp()])
301        # Detect comparison errors during insertion and lookup
302        self.assertRaises(RuntimeError, self.thetype, [BadCmp(), BadCmp()])
303        self.assertRaises(RuntimeError, s.__contains__, BadCmp())
304        # Detect errors during mutating operations
305        if hasattr(s, 'add'):
306            self.assertRaises(RuntimeError, s.add, BadCmp())
307            self.assertRaises(RuntimeError, s.discard, BadCmp())
308            self.assertRaises(RuntimeError, s.remove, BadCmp())
309
310    def test_cyclical_repr(self):
311        w = ReprWrapper()
312        s = self.thetype([w])
313        w.value = s
314        if self.thetype == set:
315            self.assertEqual(repr(s), '{set(...)}')
316        else:
317            name = repr(s).partition('(')[0]    # strip class name
318            self.assertEqual(repr(s), '%s({%s(...)})' % (name, name))
319
320    def test_cyclical_print(self):
321        w = ReprWrapper()
322        s = self.thetype([w])
323        w.value = s
324        fo = open(support.TESTFN, "w")
325        try:
326            fo.write(str(s))
327            fo.close()
328            fo = open(support.TESTFN, "r")
329            self.assertEqual(fo.read(), repr(s))
330        finally:
331            fo.close()
332            support.unlink(support.TESTFN)
333
334    def test_do_not_rehash_dict_keys(self):
335        n = 10
336        d = dict.fromkeys(map(HashCountingInt, range(n)))
337        self.assertEqual(sum(elem.hash_count for elem in d), n)
338        s = self.thetype(d)
339        self.assertEqual(sum(elem.hash_count for elem in d), n)
340        s.difference(d)
341        self.assertEqual(sum(elem.hash_count for elem in d), n)
342        if hasattr(s, 'symmetric_difference_update'):
343            s.symmetric_difference_update(d)
344        self.assertEqual(sum(elem.hash_count for elem in d), n)
345        d2 = dict.fromkeys(set(d))
346        self.assertEqual(sum(elem.hash_count for elem in d), n)
347        d3 = dict.fromkeys(frozenset(d))
348        self.assertEqual(sum(elem.hash_count for elem in d), n)
349        d3 = dict.fromkeys(frozenset(d), 123)
350        self.assertEqual(sum(elem.hash_count for elem in d), n)
351        self.assertEqual(d3, dict.fromkeys(d, 123))
352
353    def test_container_iterator(self):
354        # Bug #3680: tp_traverse was not implemented for set iterator object
355        class C(object):
356            pass
357        obj = C()
358        ref = weakref.ref(obj)
359        container = set([obj, 1])
360        obj.x = iter(container)
361        del obj, container
362        gc.collect()
363        self.assertTrue(ref() is None, "Cycle was not collected")
364
365    def test_free_after_iterating(self):
366        support.check_free_after_iterating(self, iter, self.thetype)
367
class TestSet(TestJointOps, unittest.TestCase):
    """Run the shared tests against set, plus tests of its mutating API."""
    thetype = set
    basetype = set

    def test_init(self):
        # Re-running __init__ replaces the contents of an existing set
        s = self.thetype()
        s.__init__(self.word)
        self.assertEqual(s, set(self.word))
        s.__init__(self.otherword)
        self.assertEqual(s, set(self.otherword))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = self.thetype(range(3))
        t = self.thetype(s)
        self.assertNotEqual(id(s), id(t))

    def test_set_literal(self):
        s = set([1,2,3])
        t = {1,2,3}
        self.assertEqual(s, t)

    def test_set_literal_insertion_order(self):
        # SF Issue #26020 -- Expect left to right insertion
        s = {1, 1.0, True}
        self.assertEqual(len(s), 1)
        stored_value = s.pop()
        self.assertEqual(type(stored_value), int)

    def test_set_literal_evaluation_order(self):
        # Expect left to right expression evaluation
        events = []
        def record(obj):
            events.append(obj)
        s = {record(1), record(2), record(3)}
        self.assertEqual(events, [1, 2, 3])

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, set())
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertNotEqual(id(self.s), id(dup))
        self.assertEqual(type(dup), self.basetype)

    def test_add(self):
        self.s.add('Q')
        self.assertIn('Q', self.s)
        dup = self.s.copy()
        # Adding an element that is already present changes nothing
        self.s.add('Q')
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])

    def test_remove(self):
        self.s.remove('a')
        self.assertNotIn('a', self.s)
        self.assertRaises(KeyError, self.s.remove, 'Q')
        self.assertRaises(TypeError, self.s.remove, [])
        # A set argument must match an equal frozenset element
        s = self.thetype([frozenset(self.word)])
        self.assertIn(self.thetype(self.word), s)
        s.remove(self.thetype(self.word))
        self.assertNotIn(self.thetype(self.word), s)
        self.assertRaises(KeyError, self.s.remove, self.thetype(self.word))

    def test_remove_keyerror_unpacking(self):
        # bug:  www.python.org/sf/1576657
        for v1 in ['Q', (1,)]:
            with self.assertRaises(KeyError) as cm:
                self.s.remove(v1)
            v2 = cm.exception.args[0]
            self.assertEqual(v1, v2)

    def test_remove_keyerror_set(self):
        key = self.thetype([3, 4])
        with self.assertRaises(KeyError) as cm:
            self.s.remove(key)
        # The KeyError must carry the very object that was passed in
        reported = cm.exception.args[0]
        self.assertIs(reported, key,
                      "KeyError should be {0}, not {1}".format(key, reported))

    def test_discard(self):
        self.s.discard('a')
        self.assertNotIn('a', self.s)
        # Discarding a missing element is not an error
        self.s.discard('Q')
        self.assertRaises(TypeError, self.s.discard, [])
        s = self.thetype([frozenset(self.word)])
        self.assertIn(self.thetype(self.word), s)
        s.discard(self.thetype(self.word))
        self.assertNotIn(self.thetype(self.word), s)
        s.discard(self.thetype(self.word))

    def test_pop(self):
        for i in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        # The set is now empty
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            self.assertIn(c, self.s)
        self.assertRaises(PassThru, self.s.update, check_pass_thru())
        self.assertRaises(TypeError, self.s.update, [[]])
        for p, q in (('cdc', 'abcd'), ('efgfe', 'abcefg'), ('ccb', 'abc'), ('ef', 'abcef')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.update(C(p)), None)
                self.assertEqual(s, set(q))
        for p in ('cdc', 'efgfe', 'ccb', 'ef', 'abcda'):
            q = 'ahi'
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.update(C(p), C(q)), None)
                # Compare against the *initial* contents.  Building the
                # expected value from the already-updated ``s`` would make
                # this assertion pass regardless of what update() did.
                self.assertEqual(s, set('abcba') | set(p) | set(q))

    def test_ior(self):
        self.s |= set(self.otherword)
        for c in (self.word + self.otherword):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if c in self.otherword and c in self.word:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.intersection_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.intersection_update, [[]])
        for p, q in (('cdc', 'c'), ('efgfe', ''), ('ccb', 'bc'), ('ef', '')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.intersection_update(C(p)), None)
                self.assertEqual(s, set(q))
                ss = 'abcba'
                s = self.thetype(ss)
                t = 'cbc'
                self.assertEqual(s.intersection_update(C(p), C(t)), None)
                self.assertEqual(s, set('abcba')&set(p)&set(t))

    def test_iand(self):
        self.s &= set(self.otherword)
        for c in (self.word + self.otherword):
            if c in self.otherword and c in self.word:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if c in self.word and c not in self.otherword:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.difference_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        # Kept from the original: this check targets
        # symmetric_difference_update even though it lives in this test.
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
        for p, q in (('cdc', 'ab'), ('efgfe', 'abc'), ('ccb', 'a'), ('ef', 'abc')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.difference_update(C(p)), None)
                self.assertEqual(s, set(q))

                # No arguments: nothing is removed
                s = self.thetype('abcdefghih')
                s.difference_update()
                self.assertEqual(s, self.thetype('abcdefghih'))

                s = self.thetype('abcdefghih')
                s.difference_update(C('aba'))
                self.assertEqual(s, self.thetype('cdefghih'))

                s = self.thetype('abcdefghih')
                s.difference_update(C('cdc'), C('aba'))
                self.assertEqual(s, self.thetype('efghih'))

    def test_isub(self):
        self.s -= set(self.otherword)
        for c in (self.word + self.otherword):
            if c in self.word and c not in self.otherword:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if (c in self.word) ^ (c in self.otherword):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.symmetric_difference_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
        for p, q in (('cdc', 'abd'), ('efgfe', 'abcefg'), ('ccb', 'a'), ('ef', 'abcef')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.symmetric_difference_update(C(p)), None)
                self.assertEqual(s, set(q))

    def test_ixor(self):
        self.s ^= set(self.otherword)
        for c in (self.word + self.otherword):
            if (c in self.word) ^ (c in self.otherword):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        # In-place operators must cope with the operand being the set itself
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, self.thetype())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, self.thetype())

    def test_weakref(self):
        s = self.thetype('gallahad')
        p = weakref.proxy(s)
        self.assertEqual(str(p), str(s))
        # Drop the only strong reference; the proxy must then be dead
        s = None
        self.assertRaises(ReferenceError, str, p)

    def test_rich_compare(self):
        class TestRichSetCompare:
            def __gt__(self, some_set):
                self.gt_called = True
                return False
            def __lt__(self, some_set):
                self.lt_called = True
                return False
            def __ge__(self, some_set):
                self.ge_called = True
                return False
            def __le__(self, some_set):
                self.le_called = True
                return False

        # This first tries the builtin rich set comparison, which doesn't know
        # how to handle the custom object. Upon returning NotImplemented, the
        # corresponding comparison on the right object is invoked.
        myset = {1, 2, 3}

        myobj = TestRichSetCompare()
        myset < myobj
        self.assertTrue(myobj.gt_called)

        myobj = TestRichSetCompare()
        myset > myobj
        self.assertTrue(myobj.lt_called)

        myobj = TestRichSetCompare()
        myset <= myobj
        self.assertTrue(myobj.ge_called)

        myobj = TestRichSetCompare()
        myset >= myobj
        self.assertTrue(myobj.le_called)

    @unittest.skipUnless(hasattr(set, "test_c_api"),
                         'C API test only available in a debug build')
    def test_c_api(self):
        self.assertEqual(set().test_c_api(), True)
651
class SetSubclass(set):
    """Trivial subclass of set that adds and overrides nothing."""
654
class TestSetSubclass(TestSet):
    """Re-run the TestSet cases with SetSubclass as the type under test."""
    # Results of non-mutating operations are still compared against plain set
    basetype = set
    thetype = SetSubclass
658
class SetSubclassWithKeywordArgs(set):
    """set subclass whose constructor takes an extra keyword argument."""

    def __init__(self, iterable=[], newarg=None):
        # ``newarg`` is accepted and ignored; only ``iterable`` reaches set.
        # The list default is never mutated here.
        super().__init__(iterable)
662
class TestSetSubclassWithKeywordArgs(TestSet):
    # No thetype/basetype override here: both are inherited from TestSet.

    def test_keywords_in_subclass(self):
        "SF bug #1486663 -- this used to erroneously raise a TypeError"
        # Construction alone is the test; it passes if nothing is raised
        SetSubclassWithKeywordArgs(newarg=1)
668
class TestFrozenSet(TestJointOps, unittest.TestCase):
    """Run the shared tests against frozenset, plus frozenset-only checks."""
    thetype = frozenset
    basetype = frozenset

    def test_init(self):
        # Calling __init__ again must leave the existing contents alone
        frozen = self.thetype(self.word)
        frozen.__init__(self.otherword)
        self.assertEqual(frozen, set(self.word))

    def test_singleton_empty_frozenset(self):
        f = frozenset()
        empties = [frozenset(), frozenset([]), frozenset(()), frozenset(''),
                   frozenset(), frozenset([]), frozenset(()), frozenset(''),
                   frozenset(range(0)), frozenset(frozenset()),
                   frozenset(f), f]
        # All of the empty frozensets should have just one id()
        distinct_ids = {id(obj) for obj in empties}
        self.assertEqual(len(distinct_ids), 1)

    def test_constructor_identity(self):
        original = self.thetype(range(3))
        rebuilt = self.thetype(original)
        self.assertEqual(id(original), id(rebuilt))

    def test_hash(self):
        first = self.thetype('abcdeb')
        second = self.thetype('ebecda')
        self.assertEqual(hash(first), hash(second))

        # make sure that all permutations give the same hash value
        n = 100
        seq = [randrange(n) for _ in range(n)]
        results = set()
        for _ in range(200):
            shuffle(seq)
            results.add(hash(self.thetype(seq)))
        self.assertEqual(len(results), 1)

    def test_copy(self):
        duplicate = self.s.copy()
        self.assertEqual(id(self.s), id(duplicate))

    def test_frozen_as_dictkey(self):
        seq = list(range(10)) + list('abcdefg') + ['apple']
        key1 = self.thetype(seq)
        key2 = self.thetype(reversed(seq))
        self.assertEqual(key1, key2)
        self.assertNotEqual(id(key1), id(key2))
        # Equal but distinct frozensets must address the same dict slot
        table = {key1: 42}
        self.assertEqual(table[key2], 42)

    def test_hash_caching(self):
        frozen = self.thetype('abcdcda')
        self.assertEqual(hash(frozen), hash(frozen))

    def test_hash_effectiveness(self):
        # Every subset of {1, ..., 13} is expected to hash differently
        bits = 13
        element_masks = [(pos + 1, 1 << pos) for pos in range(bits)]
        seen = {
            hash(frozenset([elem for elem, mask in element_masks
                            if mask & combo]))
            for combo in range(2 ** bits)
        }
        self.assertEqual(len(seen), 2 ** bits)

        def zf_range(count):
            # https://en.wikipedia.org/wiki/Set-theoretic_definition_of_natural_numbers
            ordinals = [frozenset()]
            while len(ordinals) < count:
                ordinals.append(frozenset(ordinals))
            return ordinals[:count]

        def powerset(items):
            for size in range(len(items) + 1):
                for combo in itertools.combinations(items, size):
                    yield frozenset(combo)

        # The low ``size`` bits of the hashes of all 2**size subsets must
        # cover more than a quarter of the 2**size possible values.
        for size in range(18):
            buckets = 2 ** size
            low_bits = buckets - 1
            for make_items in (range, zf_range):
                subsets = powerset(make_items(size))
                used = len({hash(subset) & low_bits for subset in subsets})
                self.assertGreater(4 * used, buckets)
750
class FrozenSetSubclass(frozenset):
    """Trivial subclass of frozenset that adds and overrides nothing."""
753
class TestFrozenSetSubclass(TestFrozenSet):
    """Re-run the TestFrozenSet cases against FrozenSetSubclass.

    The identity-based tests are overridden: the subclass is expected to
    produce distinct objects where those tests assert identity for frozenset.
    """
    basetype = frozenset
    thetype = FrozenSetSubclass

    def test_constructor_identity(self):
        original = self.thetype(range(3))
        rebuilt = self.thetype(original)
        self.assertNotEqual(id(original), id(rebuilt))

    def test_copy(self):
        duplicate = self.s.copy()
        self.assertNotEqual(id(self.s), id(duplicate))

    def test_nested_empty_constructor(self):
        empty = self.thetype()
        wrapped = self.thetype(empty)
        self.assertEqual(empty, wrapped)

    def test_singleton_empty_frozenset(self):
        Frozenset = self.thetype
        f = frozenset()
        F = Frozenset()
        empties = [Frozenset(), Frozenset([]), Frozenset(()), Frozenset(''),
                   Frozenset(), Frozenset([]), Frozenset(()), Frozenset(''),
                   Frozenset(range(0)), Frozenset(Frozenset()),
                   Frozenset(frozenset()), f, F, Frozenset(f), Frozenset(F)]
        # All empty frozenset subclass instances should have different ids
        distinct_ids = {id(obj) for obj in empties}
        self.assertEqual(len(distinct_ids), len(empties))
782
783# Tests taken from test_sets.py =============================================
784
# Module-level empty set used as an operand by the TestBasicOps tests.
empty_set = set()
786
787#==============================================================================
788
class TestBasicOps:
    # Mixin with the checks shared by the TestBasicOps* cases.  A concrete
    # case mixes in unittest.TestCase and provides a setUp() that defines
    # self.set, self.dup (an equal set built separately), self.values and
    # self.length, plus self.repr unless it overrides test_repr().

    def test_repr(self):
        # Cases set self.repr to None to skip the exact-repr comparison.
        if self.repr is not None:
            self.assertEqual(repr(self.set), self.repr)

    def check_repr_against_values(self):
        # Order-insensitive repr check: compare the sorted element reprs.
        text = repr(self.set)
        self.assertTrue(text.startswith('{'))
        self.assertTrue(text.endswith('}'))

        result = sorted(text[1:-1].split(', '))
        sorted_repr_values = sorted(repr(value) for value in self.values)
        self.assertEqual(result, sorted_repr_values)

    def test_print(self):
        # str() of the set, written to a file and read back, must equal
        # repr().  Each handle is closed by its own `with` block, so a
        # failing open() cannot reach cleanup code that refers to an unbound
        # handle, and the scratch file is removed on every path.
        try:
            with open(support.TESTFN, "w") as fo:
                fo.write(str(self.set))
            with open(support.TESTFN, "r") as fo:
                self.assertEqual(fo.read(), repr(self.set))
        finally:
            support.unlink(support.TESTFN)

    def test_length(self):
        self.assertEqual(len(self.set), self.length)

    def test_self_equality(self):
        self.assertEqual(self.set, self.set)

    def test_equivalent_equality(self):
        self.assertEqual(self.set, self.dup)

    def test_copy(self):
        self.assertEqual(self.set.copy(), self.dup)

    def test_self_union(self):
        result = self.set | self.set
        self.assertEqual(result, self.dup)

    def test_empty_union(self):
        result = self.set | empty_set
        self.assertEqual(result, self.dup)

    def test_union_empty(self):
        result = empty_set | self.set
        self.assertEqual(result, self.dup)

    def test_self_intersection(self):
        result = self.set & self.set
        self.assertEqual(result, self.dup)

    def test_empty_intersection(self):
        result = self.set & empty_set
        self.assertEqual(result, empty_set)

    def test_intersection_empty(self):
        result = empty_set & self.set
        self.assertEqual(result, empty_set)

    def test_self_isdisjoint(self):
        # A set is disjoint from itself only when it is empty.
        result = self.set.isdisjoint(self.set)
        self.assertEqual(result, not self.set)

    def test_empty_isdisjoint(self):
        result = self.set.isdisjoint(empty_set)
        self.assertEqual(result, True)

    def test_isdisjoint_empty(self):
        result = empty_set.isdisjoint(self.set)
        self.assertEqual(result, True)

    def test_self_symmetric_difference(self):
        result = self.set ^ self.set
        self.assertEqual(result, empty_set)

    def test_empty_symmetric_difference(self):
        result = self.set ^ empty_set
        self.assertEqual(result, self.set)

    def test_self_difference(self):
        result = self.set - self.set
        self.assertEqual(result, empty_set)

    def test_empty_difference(self):
        result = self.set - empty_set
        self.assertEqual(result, self.dup)

    def test_empty_difference_rev(self):
        result = empty_set - self.set
        self.assertEqual(result, empty_set)

    def test_iteration(self):
        for v in self.set:
            self.assertIn(v, self.values)
        setiter = iter(self.set)
        self.assertEqual(setiter.__length_hint__(), len(self.set))

    def test_pickling(self):
        # Round-trip through every pickle protocol.  The local is not called
        # `copy` so that it does not shadow the imported copy module.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            p = pickle.dumps(self.set, proto)
            restored = pickle.loads(p)
            self.assertEqual(self.set, restored,
                             "%s != %s" % (self.set, restored))

    def test_issue_37219(self):
        # A non-iterable argument must raise TypeError.
        with self.assertRaises(TypeError):
            set().difference(123)
        with self.assertRaises(TypeError):
            set().difference_update(123)
903
904#------------------------------------------------------------------------------
905
class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase):
    # Fixture: the empty set.
    def setUp(self):
        members = []
        self.case = "empty set"
        self.values = members
        self.set = set(members)
        self.dup = set(members)
        self.length = 0
        self.repr = "set()"
914
915#------------------------------------------------------------------------------
916
class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
    # Fixture: a set holding the single number 3.
    def setUp(self):
        members = [3]
        self.case = "unit set (number)"
        self.values = members
        self.set = set(members)
        self.dup = set(members)
        self.length = 1
        self.repr = "{3}"

    def test_in(self):
        # The one stored value is reported as a member.
        self.assertIn(3, self.set)

    def test_not_in(self):
        # A value that was never added is not.
        self.assertNotIn(2, self.set)
931
932#------------------------------------------------------------------------------
933
class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
    # Fixture: a set holding a single tuple.
    def setUp(self):
        members = [(0, "zero")]
        self.case = "unit set (tuple)"
        self.values = members
        self.set = set(members)
        self.dup = set(members)
        self.length = 1
        self.repr = "{(0, 'zero')}"

    def test_in(self):
        # The stored tuple is found by an equal tuple.
        self.assertIn((0, "zero"), self.set)

    def test_not_in(self):
        # An unrelated value is not a member.
        self.assertNotIn(9, self.set)
948
949#------------------------------------------------------------------------------
950
class TestBasicOpsTriple(TestBasicOps, unittest.TestCase):
    # Fixture: three elements of different types; repr is None, so the
    # inherited test_repr() makes no comparison.
    def setUp(self):
        members = [0, "zero", operator.add]
        self.case = "triple set"
        self.values = members
        self.set = set(members)
        self.dup = set(members)
        self.length = 3
        self.repr = None
959
960#------------------------------------------------------------------------------
961
class TestBasicOpsString(TestBasicOps, unittest.TestCase):
    # Fixture: three one-character strings.
    def setUp(self):
        members = ["a", "b", "c"]
        self.case = "string set"
        self.values = members
        self.set = set(members)
        self.dup = set(members)
        self.length = 3

    def test_repr(self):
        # Compare the repr element by element, ignoring order.
        self.check_repr_against_values()
972
973#------------------------------------------------------------------------------
974
class TestBasicOpsBytes(TestBasicOps, unittest.TestCase):
    # Fixture: three one-byte bytes objects.
    def setUp(self):
        members = [b"a", b"b", b"c"]
        self.case = "bytes set"
        self.values = members
        self.set = set(members)
        self.dup = set(members)
        self.length = 3

    def test_repr(self):
        # Compare the repr element by element, ignoring order.
        self.check_repr_against_values()
985
986#------------------------------------------------------------------------------
987
class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
    # Fixture: str and bytes elements in one set, built with BytesWarning
    # ignored for the duration of each test.
    def setUp(self):
        self._warning_filters = support.check_warnings()
        self._warning_filters.__enter__()
        # Register the exit right away: unittest does not call tearDown()
        # when setUp() raises, but it always runs registered cleanups, so the
        # warning state is restored even if the lines below fail.
        self.addCleanup(self._warning_filters.__exit__, None, None, None)
        warnings.simplefilter('ignore', BytesWarning)
        self.case   = "string and bytes set"
        self.values = ["a", "b", b"a", b"b"]
        self.set    = set(self.values)
        self.dup    = set(self.values)
        self.length = 4

    def test_repr(self):
        self.check_repr_against_values()
1004
1005#==============================================================================
1006
def baditer():
    'Generator that raises TypeError the first time it is advanced'
    if False:
        # Never executed; the yield only makes this function a generator.
        yield True
    raise TypeError
1010
def gooditer():
    'Generator that yields the single value True'
    yield from (True,)
1013
1014class TestExceptionPropagation(unittest.TestCase):
1015    """SF 628246:  Set constructor should not trap iterator TypeErrors"""
1016
1017    def test_instanceWithException(self):
1018        self.assertRaises(TypeError, set, baditer())
1019
1020    def test_instancesWithoutException(self):
1021        # All of these iterables should load without exception.
1022        set([1,2,3])
1023        set((1,2,3))
1024        set({'one':1, 'two':2, 'three':3})
1025        set(range(3))
1026        set('abc')
1027        set(gooditer())
1028
1029    def test_changingSizeWhileIterating(self):
1030        s = set([1,2,3])
1031        try:
1032            for i in s:
1033                s.update([4])
1034        except RuntimeError:
1035            pass
1036        else:
1037            self.fail("no exception when changing size during iteration")
1038
1039#==============================================================================
1040
class TestSetOfSets(unittest.TestCase):
    # A frozenset can be stored in, and removed from, a set.
    def test_constructor(self):
        member = frozenset([1])
        container = set([member])
        popped = container.pop()
        self.assertEqual(type(popped), frozenset)
        # Rebuild set of sets with .add method, then take the element out.
        container.add(member)
        container.remove(member)
        # Verify that remove worked
        self.assertEqual(container, set())
        # Absence of KeyError indicates working fine
        container.discard(member)
1051
1052#==============================================================================
1053
class TestBinaryOps(unittest.TestCase):
    # Binary operators and isdisjoint() applied to the fixture {2, 4, 6}
    # and a second set that is a subset, a superset, overlapping or disjoint.

    def setUp(self):
        self.set = set((2, 4, 6))

    def test_eq(self):              # SF bug 643115
        from_dict = set({2:1,4:3,6:5})
        self.assertEqual(self.set, from_dict)

    # --- union ---------------------------------------------------------

    def test_union_subset(self):
        self.assertEqual(self.set | set([2]), set((2, 4, 6)))

    def test_union_superset(self):
        self.assertEqual(self.set | set([2, 4, 6, 8]), set([2, 4, 6, 8]))

    def test_union_overlap(self):
        self.assertEqual(self.set | set([3, 4, 5]), set([2, 3, 4, 5, 6]))

    def test_union_non_overlap(self):
        self.assertEqual(self.set | set([8]), set([2, 4, 6, 8]))

    # --- intersection --------------------------------------------------

    def test_intersection_subset(self):
        self.assertEqual(self.set & set((2, 4)), set((2, 4)))

    def test_intersection_superset(self):
        self.assertEqual(self.set & set([2, 4, 6, 8]), set([2, 4, 6]))

    def test_intersection_overlap(self):
        self.assertEqual(self.set & set([3, 4, 5]), set([4]))

    def test_intersection_non_overlap(self):
        self.assertEqual(self.set & set([8]), empty_set)

    # --- isdisjoint ----------------------------------------------------

    def test_isdisjoint_subset(self):
        self.assertEqual(self.set.isdisjoint(set((2, 4))), False)

    def test_isdisjoint_superset(self):
        self.assertEqual(self.set.isdisjoint(set([2, 4, 6, 8])), False)

    def test_isdisjoint_overlap(self):
        self.assertEqual(self.set.isdisjoint(set([3, 4, 5])), False)

    def test_isdisjoint_non_overlap(self):
        self.assertEqual(self.set.isdisjoint(set([8])), True)

    # --- symmetric difference ------------------------------------------

    def test_sym_difference_subset(self):
        self.assertEqual(self.set ^ set((2, 4)), set([6]))

    def test_sym_difference_superset(self):
        self.assertEqual(self.set ^ set((2, 4, 6, 8)), set([8]))

    def test_sym_difference_overlap(self):
        self.assertEqual(self.set ^ set((3, 4, 5)), set([2, 3, 5, 6]))

    def test_sym_difference_non_overlap(self):
        self.assertEqual(self.set ^ set([8]), set([2, 4, 6, 8]))
1124
1125#==============================================================================
1126
class TestUpdateOps(unittest.TestCase):
    # In-place operators and the matching *_update() methods applied to the
    # fixture {2, 4, 6}.

    def setUp(self):
        self.set = set((2, 4, 6))

    def _assert_contents(self, expected):
        # Compare the fixture, after the operation, with the given members.
        self.assertEqual(self.set, set(expected))

    # --- union ---------------------------------------------------------

    def test_union_subset(self):
        self.set |= set([2])
        self._assert_contents((2, 4, 6))

    def test_union_superset(self):
        self.set |= set([2, 4, 6, 8])
        self._assert_contents([2, 4, 6, 8])

    def test_union_overlap(self):
        self.set |= set([3, 4, 5])
        self._assert_contents([2, 3, 4, 5, 6])

    def test_union_non_overlap(self):
        self.set |= set([8])
        self._assert_contents([2, 4, 6, 8])

    def test_union_method_call(self):
        self.set.update(set([3, 4, 5]))
        self._assert_contents([2, 3, 4, 5, 6])

    # --- intersection --------------------------------------------------

    def test_intersection_subset(self):
        self.set &= set((2, 4))
        self._assert_contents((2, 4))

    def test_intersection_superset(self):
        self.set &= set([2, 4, 6, 8])
        self._assert_contents([2, 4, 6])

    def test_intersection_overlap(self):
        self.set &= set([3, 4, 5])
        self._assert_contents([4])

    def test_intersection_non_overlap(self):
        self.set &= set([8])
        self.assertEqual(self.set, empty_set)

    def test_intersection_method_call(self):
        self.set.intersection_update(set([3, 4, 5]))
        self._assert_contents([4])

    # --- symmetric difference ------------------------------------------

    def test_sym_difference_subset(self):
        self.set ^= set((2, 4))
        self._assert_contents([6])

    def test_sym_difference_superset(self):
        self.set ^= set((2, 4, 6, 8))
        self._assert_contents([8])

    def test_sym_difference_overlap(self):
        self.set ^= set((3, 4, 5))
        self._assert_contents([2, 3, 5, 6])

    def test_sym_difference_non_overlap(self):
        self.set ^= set([8])
        self._assert_contents([2, 4, 6, 8])

    def test_sym_difference_method_call(self):
        self.set.symmetric_difference_update(set([3, 4, 5]))
        self._assert_contents([2, 3, 5, 6])

    # --- difference ----------------------------------------------------

    def test_difference_subset(self):
        self.set -= set((2, 4))
        self._assert_contents([6])

    def test_difference_superset(self):
        self.set -= set((2, 4, 6, 8))
        self._assert_contents([])

    def test_difference_overlap(self):
        self.set -= set((3, 4, 5))
        self._assert_contents([2, 6])

    def test_difference_non_overlap(self):
        self.set -= set([8])
        self._assert_contents([2, 4, 6])

    def test_difference_method_call(self):
        self.set.difference_update(set([3, 4, 5]))
        self._assert_contents([2, 6])
1210
1211#==============================================================================
1212
class TestMutate(unittest.TestCase):
    # Element-level mutation (add, remove, discard, clear, pop, update) of a
    # three-element set.

    def setUp(self):
        self.values = ["a", "b", "c"]
        self.set = set(self.values)

    def test_add_present(self):
        self.set.add("c")
        self.assertEqual(self.set, set("abc"))

    def test_add_absent(self):
        self.set.add("d")
        self.assertEqual(self.set, set("abcd"))

    def test_add_until_full(self):
        # The length must grow by one with every distinct value added.
        tmp = set()
        for expected_len, v in enumerate(self.values, 1):
            tmp.add(v)
            self.assertEqual(len(tmp), expected_len)
        self.assertEqual(tmp, self.set)

    def test_remove_present(self):
        self.set.remove("b")
        self.assertEqual(self.set, set("ac"))

    def test_remove_absent(self):
        with self.assertRaises(
                LookupError,
                msg="Removing missing element should have raised LookupError"):
            self.set.remove("d")

    def test_remove_until_empty(self):
        # The length must shrink by one with every value removed.
        expected_len = len(self.set)
        for v in self.values:
            self.set.remove(v)
            expected_len -= 1
            self.assertEqual(len(self.set), expected_len)

    def test_discard_present(self):
        self.set.discard("c")
        self.assertEqual(self.set, set("ab"))

    def test_discard_absent(self):
        # Unlike remove(), discard() of a missing element is a no-op.
        self.set.discard("d")
        self.assertEqual(self.set, set("abc"))

    def test_clear(self):
        self.set.clear()
        self.assertEqual(len(self.set), 0)

    def test_pop(self):
        # Popping until empty must return every original value exactly once.
        popped = {}
        while self.set:
            popped[self.set.pop()] = None
        self.assertEqual(len(popped), len(self.values))
        for v in self.values:
            self.assertIn(v, popped)

    def test_update_empty_tuple(self):
        self.set.update(())
        self.assertEqual(self.set, set(self.values))

    def test_update_unit_tuple_overlap(self):
        self.set.update(("a",))
        self.assertEqual(self.set, set(self.values))

    def test_update_unit_tuple_non_overlap(self):
        self.set.update(("a", "z"))
        self.assertEqual(self.set, set(self.values + ["z"]))
1284
1285#==============================================================================
1286
1287class TestSubsets:
1288
1289    case2method = {"<=": "issubset",
1290                   ">=": "issuperset",
1291                  }
1292
1293    reverse = {"==": "==",
1294               "!=": "!=",
1295               "<":  ">",
1296               ">":  "<",
1297               "<=": ">=",
1298               ">=": "<=",
1299              }
1300
1301    def test_issubset(self):
1302        x = self.left
1303        y = self.right
1304        for case in "!=", "==", "<", "<=", ">", ">=":
1305            expected = case in self.cases
1306            # Test the binary infix spelling.
1307            result = eval("x" + case + "y", locals())
1308            self.assertEqual(result, expected)
1309            # Test the "friendly" method-name spelling, if one exists.
1310            if case in TestSubsets.case2method:
1311                method = getattr(x, TestSubsets.case2method[case])
1312                result = method(y)
1313                self.assertEqual(result, expected)
1314
1315            # Now do the same for the operands reversed.
1316            rcase = TestSubsets.reverse[case]
1317            result = eval("y" + rcase + "x", locals())
1318            self.assertEqual(result, expected)
1319            if rcase in TestSubsets.case2method:
1320                method = getattr(y, TestSubsets.case2method[rcase])
1321                result = method(x)
1322                self.assertEqual(result, expected)
1323#------------------------------------------------------------------------------
1324
class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase):
    # Two empty sets: equal, and each a subset and a superset of the other.
    name  = "both empty"
    cases = ("==", "<=", ">=")
    left  = set()
    right = set()
1330
1331#------------------------------------------------------------------------------
1332
class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase):
    # Two equal, non-empty sets.
    name  = "equal pair"
    cases = ("==", "<=", ">=")
    left  = {1, 2}
    right = {1, 2}
1338
1339#------------------------------------------------------------------------------
1340
class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase):
    # The empty set compared with a non-empty one.
    name  = "one empty, one non-empty"
    cases = ("!=", "<", "<=")
    left  = set()
    right = {1, 2}
1346
1347#------------------------------------------------------------------------------
1348
class TestSubsetPartial(TestSubsets, unittest.TestCase):
    # A non-empty set compared with a set that adds one more element.
    name  = "one a non-empty proper subset of other"
    cases = ("!=", "<", "<=")
    left  = {1}
    right = {1, 2}
1354
1355#------------------------------------------------------------------------------
1356
class TestSubsetNonOverlap(TestSubsets, unittest.TestCase):
    # Two non-empty sets with no element in common: only "!=" holds.
    left  = set([1])
    right = set([2])
    name  = "neither empty, neither contains"
    # One-element tuple, like the sibling cases.  A bare "!=" string would
    # turn `case in self.cases` into a substring test that merely happens to
    # give the right answers for the operators checked.
    cases = ("!=",)
1362
1363#==============================================================================
1364
1365class TestOnlySetsInBinaryOps:
1366
1367    def test_eq_ne(self):
1368        # Unlike the others, this is testing that == and != *are* allowed.
1369        self.assertEqual(self.other == self.set, False)
1370        self.assertEqual(self.set == self.other, False)
1371        self.assertEqual(self.other != self.set, True)
1372        self.assertEqual(self.set != self.other, True)
1373
1374    def test_ge_gt_le_lt(self):
1375        self.assertRaises(TypeError, lambda: self.set < self.other)
1376        self.assertRaises(TypeError, lambda: self.set <= self.other)
1377        self.assertRaises(TypeError, lambda: self.set > self.other)
1378        self.assertRaises(TypeError, lambda: self.set >= self.other)
1379
1380        self.assertRaises(TypeError, lambda: self.other < self.set)
1381        self.assertRaises(TypeError, lambda: self.other <= self.set)
1382        self.assertRaises(TypeError, lambda: self.other > self.set)
1383        self.assertRaises(TypeError, lambda: self.other >= self.set)
1384
1385    def test_update_operator(self):
1386        try:
1387            self.set |= self.other
1388        except TypeError:
1389            pass
1390        else:
1391            self.fail("expected TypeError")
1392
1393    def test_update(self):
1394        if self.otherIsIterable:
1395            self.set.update(self.other)
1396        else:
1397            self.assertRaises(TypeError, self.set.update, self.other)
1398
1399    def test_union(self):
1400        self.assertRaises(TypeError, lambda: self.set | self.other)
1401        self.assertRaises(TypeError, lambda: self.other | self.set)
1402        if self.otherIsIterable:
1403            self.set.union(self.other)
1404        else:
1405            self.assertRaises(TypeError, self.set.union, self.other)
1406
1407    def test_intersection_update_operator(self):
1408        try:
1409            self.set &= self.other
1410        except TypeError:
1411            pass
1412        else:
1413            self.fail("expected TypeError")
1414
1415    def test_intersection_update(self):
1416        if self.otherIsIterable:
1417            self.set.intersection_update(self.other)
1418        else:
1419            self.assertRaises(TypeError,
1420                              self.set.intersection_update,
1421                              self.other)
1422
1423    def test_intersection(self):
1424        self.assertRaises(TypeError, lambda: self.set & self.other)
1425        self.assertRaises(TypeError, lambda: self.other & self.set)
1426        if self.otherIsIterable:
1427            self.set.intersection(self.other)
1428        else:
1429            self.assertRaises(TypeError, self.set.intersection, self.other)
1430
1431    def test_sym_difference_update_operator(self):
1432        try:
1433            self.set ^= self.other
1434        except TypeError:
1435            pass
1436        else:
1437            self.fail("expected TypeError")
1438
1439    def test_sym_difference_update(self):
1440        if self.otherIsIterable:
1441            self.set.symmetric_difference_update(self.other)
1442        else:
1443            self.assertRaises(TypeError,
1444                              self.set.symmetric_difference_update,
1445                              self.other)
1446
1447    def test_sym_difference(self):
1448        self.assertRaises(TypeError, lambda: self.set ^ self.other)
1449        self.assertRaises(TypeError, lambda: self.other ^ self.set)
1450        if self.otherIsIterable:
1451            self.set.symmetric_difference(self.other)
1452        else:
1453            self.assertRaises(TypeError, self.set.symmetric_difference, self.other)
1454
1455    def test_difference_update_operator(self):
1456        try:
1457            self.set -= self.other
1458        except TypeError:
1459            pass
1460        else:
1461            self.fail("expected TypeError")
1462
1463    def test_difference_update(self):
1464        if self.otherIsIterable:
1465            self.set.difference_update(self.other)
1466        else:
1467            self.assertRaises(TypeError,
1468                              self.set.difference_update,
1469                              self.other)
1470
1471    def test_difference(self):
1472        self.assertRaises(TypeError, lambda: self.set - self.other)
1473        self.assertRaises(TypeError, lambda: self.other - self.set)
1474        if self.otherIsIterable:
1475            self.set.difference(self.other)
1476        else:
1477            self.assertRaises(TypeError, self.set.difference, self.other)
1478
1479#------------------------------------------------------------------------------
1480
class TestOnlySetsNumeric(TestOnlySetsInBinaryOps, unittest.TestCase):
    # The other operand is an int, flagged as not iterable.
    def setUp(self):
        self.otherIsIterable = False
        self.other = 19
        self.set = set((1, 2, 3))
1486
1487#------------------------------------------------------------------------------
1488
class TestOnlySetsDict(TestOnlySetsInBinaryOps, unittest.TestCase):
    # The other operand is a dict, flagged as iterable.
    def setUp(self):
        self.otherIsIterable = True
        self.other = {1:2, 3:4}
        self.set = set((1, 2, 3))
1494
1495#------------------------------------------------------------------------------
1496
class TestOnlySetsOperator(TestOnlySetsInBinaryOps, unittest.TestCase):
    # The other operand is a function (operator.add), flagged as not iterable.
    def setUp(self):
        self.otherIsIterable = False
        self.other = operator.add
        self.set = set((1, 2, 3))
1502
1503#------------------------------------------------------------------------------
1504
class TestOnlySetsTuple(TestOnlySetsInBinaryOps, unittest.TestCase):
    # The other operand is a tuple, flagged as iterable.
    def setUp(self):
        self.otherIsIterable = True
        self.other = (2, 4, 6)
        self.set = set((1, 2, 3))
1510
1511#------------------------------------------------------------------------------
1512
class TestOnlySetsString(TestOnlySetsInBinaryOps, unittest.TestCase):
    # The other operand is a str, flagged as iterable.
    def setUp(self):
        self.otherIsIterable = True
        self.other = 'abc'
        self.set = set((1, 2, 3))
1518
1519#------------------------------------------------------------------------------
1520
class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase):
    # The other operand is a generator over 0, 2, 4, 6, 8, flagged as
    # iterable; a fresh one is created for every test.
    def setUp(self):
        self.set = set((1, 2, 3))
        self.other = (i for i in range(0, 10, 2))
        self.otherIsIterable = True
1529
1530#==============================================================================
1531
1532class TestCopying:
1533
1534    def test_copy(self):
1535        dup = self.set.copy()
1536        dup_list = sorted(dup, key=repr)
1537        set_list = sorted(self.set, key=repr)
1538        self.assertEqual(len(dup_list), len(set_list))
1539        for i in range(len(dup_list)):
1540            self.assertTrue(dup_list[i] is set_list[i])
1541
1542    def test_deep_copy(self):
1543        dup = copy.deepcopy(self.set)
1544        ##print type(dup), repr(dup)
1545        dup_list = sorted(dup, key=repr)
1546        set_list = sorted(self.set, key=repr)
1547        self.assertEqual(len(dup_list), len(set_list))
1548        for i in range(len(dup_list)):
1549            self.assertEqual(dup_list[i], set_list[i])
1550
1551#------------------------------------------------------------------------------
1552
class TestCopyingEmpty(TestCopying, unittest.TestCase):
    # Copying an empty set.
    def setUp(self):
        self.set = set(())
1556
1557#------------------------------------------------------------------------------
1558
class TestCopyingSingleton(TestCopying, unittest.TestCase):
    # Copying a set with a single string element.
    def setUp(self):
        self.set = {"hello"}
1562
1563#------------------------------------------------------------------------------
1564
class TestCopyingTriple(TestCopying, unittest.TestCase):
    # Copying a set whose three elements have different types.
    def setUp(self):
        self.set = {"zero", 0, None}
1568
1569#------------------------------------------------------------------------------
1570
class TestCopyingTuple(TestCopying, unittest.TestCase):
    # Copying a set that holds one tuple.
    def setUp(self):
        self.set = {(1, 2)}
1574
1575#------------------------------------------------------------------------------
1576
class TestCopyingNested(TestCopying, unittest.TestCase):
    # Copying a set that holds one tuple of tuples.
    def setUp(self):
        self.set = {((1, 2), (3, 4))}
1580
1581#==============================================================================
1582
class TestIdentities(unittest.TestCase):
    # Algebraic identities relating the set operators, checked on two
    # overlapping sets of letters.

    def setUp(self):
        self.a = set('abracadabra')
        self.b = set('alacazam')

    def test_binopsVsSubsets(self):
        a, b = self.a, self.b
        union = a | b
        common = a & b
        self.assertLess(a - b, a)
        self.assertLess(b - a, b)
        self.assertLess(common, a)
        self.assertLess(common, b)
        self.assertGreater(union, a)
        self.assertGreater(union, b)
        self.assertLess(a ^ b, union)

    def test_commutativity(self):
        a, b = self.a, self.b
        self.assertEqual(a & b, b & a)
        self.assertEqual(a | b, b | a)
        self.assertEqual(a ^ b, b ^ a)
        # Difference is the one operator here that is not commutative.
        if a != b:
            self.assertNotEqual(a - b, b - a)

    def test_summations(self):
        # check that sums of parts equal the whole
        a, b = self.a, self.b
        whole = a | b
        only_a, only_b, both = a - b, b - a, a & b
        self.assertEqual(only_a | both | only_b, whole)
        self.assertEqual(both | (a ^ b), whole)
        self.assertEqual(a | only_b, whole)
        self.assertEqual(only_a | b, whole)
        self.assertEqual(only_a | both, a)
        self.assertEqual(only_b | both, b)
        self.assertEqual(only_a | only_b, a ^ b)

    def test_exclusion(self):
        # check that inverse operations show non-overlap
        a, b = self.a, self.b
        nothing = set()
        self.assertEqual((a - b) & b, nothing)
        self.assertEqual((b - a) & a, nothing)
        self.assertEqual((a & b) & (a ^ b), nothing)
1623
1624# Tests derived from test_itertools.py =======================================
1625
def R(seqn):
    'Plain generator yielding each item of seqn in turn'
    yield from seqn
1630
class G:
    'Wrapper exposing only __getitem__ (old-style sequence iteration)'

    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, i):
        # Delegate indexing, and therefore IndexError, to the wrapped data.
        item = self.seqn[i]
        return item
1637
class I:
    'Hand-written iterator: __iter__ returns self, __next__ walks seqn'

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        pos = self.i
        if pos >= len(self.seqn):
            raise StopIteration
        # Read first, advance second, so a failing lookup leaves i unchanged.
        value = self.seqn[pos]
        self.i = pos + 1
        return value
1650
class Ig:
    'Iterable whose __iter__ is a generator over seqn'

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        yield from self.seqn
1659
class X:
    'Has __next__ only: defines neither __iter__ nor __getitem__'

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __next__(self):
        pos = self.i
        if pos >= len(self.seqn):
            raise StopIteration
        # Read first, advance second, so a failing lookup leaves i unchanged.
        value = self.seqn[pos]
        self.i = pos + 1
        return value
1670
class N:
    'Has __iter__ returning self but defines no __next__'

    def __init__(self, seqn):
        self.i = 0
        self.seqn = seqn

    def __iter__(self):
        return self
1678
class E:
    'Iterator whose __next__ always fails with ZeroDivisionError'

    def __init__(self, seqn):
        self.i = 0
        self.seqn = seqn

    def __iter__(self):
        return self

    def __next__(self):
        # Deliberate integer division by zero; nothing is ever returned.
        numerator, denominator = 3, 0
        numerator // denominator
1688
class S:
    'Iterator that ignores its argument and is exhausted from the start'

    def __init__(self, seqn):
        # seqn is accepted for signature compatibility only and not stored.
        pass

    def __iter__(self):
        return self

    def __next__(self):
        raise StopIteration()
1697
1698from itertools import chain
def L(seqn):
    'Stack several iterator layers (G, Ig, R, map, chain) over seqn'
    indexable = G(seqn)
    generator = R(Ig(indexable))
    passthrough = map(lambda x:x, generator)
    return chain(passthrough)
1702
class TestVariousIteratorArgs(unittest.TestCase):
    """Feed set and frozenset every flavour of iterable helper in this module.

    The well-behaved wrappers (G, I, Ig, S, L, R) must produce the same
    result as the data they wrap.  X and N are not iterable and must be
    rejected with TypeError.  E raises ZeroDivisionError from __next__ and
    that exception must reach the caller unchanged.
    """

    def test_constructor(self):
        for cons in (set, frozenset):
            for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
                for g in (G, I, Ig, S, L, R):
                    # Sets are unordered and the data mixes types, so both
                    # sides are ordered by repr before comparing.
                    self.assertEqual(sorted(cons(g(s)), key=repr), sorted(g(s), key=repr))
                self.assertRaises(TypeError, cons, X(s))
                self.assertRaises(TypeError, cons, N(s))
                self.assertRaises(ZeroDivisionError, cons, E(s))

    def test_inline_methods(self):
        s = set('november')
        for data in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5), 'december'):
            for meth in (s.union, s.intersection, s.difference, s.symmetric_difference, s.isdisjoint):
                # S is left out here: it yields nothing, so meth(S(data))
                # would not match meth(data).
                for g in (G, I, Ig, L, R):
                    expected = meth(data)
                    actual = meth(g(data))
                    if isinstance(expected, bool):
                        # isdisjoint() returns a bool rather than a set.
                        self.assertEqual(actual, expected)
                    else:
                        self.assertEqual(sorted(actual, key=repr), sorted(expected, key=repr))
                # Wrap the current data, not the set under test, so the
                # negative checks follow the loop like test_inplace_methods.
                self.assertRaises(TypeError, meth, X(data))
                self.assertRaises(TypeError, meth, N(data))
                self.assertRaises(ZeroDivisionError, meth, E(data))

    def test_inplace_methods(self):
        for data in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5), 'december'):
            for methname in ('update', 'intersection_update',
                             'difference_update', 'symmetric_difference_update'):
                for g in (G, I, Ig, S, L, R):
                    # Apply the same mutation to two equal sets: once with a
                    # materialised list, once with the wrapper itself.
                    s = set('january')
                    t = s.copy()
                    getattr(s, methname)(list(g(data)))
                    getattr(t, methname)(g(data))
                    self.assertEqual(sorted(s, key=repr), sorted(t, key=repr))

                self.assertRaises(TypeError, getattr(set('january'), methname), X(data))
                self.assertRaises(TypeError, getattr(set('january'), methname), N(data))
                self.assertRaises(ZeroDivisionError, getattr(set('january'), methname), E(data))
1743
class bad_eq:
    """Hash-colliding object whose __eq__ can sabotage a set mid-operation.

    Reads the module globals be_bad and set2.  While be_bad is true,
    comparing empties set2 and raises ZeroDivisionError; otherwise equality
    is plain identity.
    """

    def __hash__(self):
        # Constant hash: every instance collides, forcing __eq__ calls.
        return 0

    def __eq__(self, other):
        if not be_bad:
            return self is other
        set2.clear()
        raise ZeroDivisionError
1752
class bad_dict_clear:
    """Hash-colliding object whose __eq__ can empty a dict mid-operation.

    Reads the module globals be_bad and dict2.  While be_bad is true,
    comparing clears dict2 first; the result is always an identity check.
    """

    def __hash__(self):
        # Constant hash: every instance collides, forcing __eq__ calls.
        return 0

    def __eq__(self, other):
        if be_bad:
            dict2.clear()
        return other is self
1760
class TestWeirdBugs(unittest.TestCase):
    """Regression tests in which a set's source is mutated while it is in use.

    Apart from one assertRaises in test_8420_set_merge, these tests assert
    nothing: they pass as long as the interpreter gets through the
    statements.  Statement order matters throughout.
    """

    def test_8420_set_merge(self):
        # This used to segfault
        # bad_eq and bad_dict_clear read these module-level names, so they
        # have to be assigned as globals rather than locals.
        global be_bad, set2, dict2
        # Build the sets with be_bad off: every instance hashes to 0, and
        # while the flag is off __eq__ is a plain identity test.
        be_bad = False
        set1 = {bad_eq()}
        set2 = {bad_eq() for i in range(75)}
        # With the flag on, bad_eq.__eq__ clears set2 and raises.
        be_bad = True
        self.assertRaises(ZeroDivisionError, set1.update, set2)

        # Same idea with a dict as the argument: bad_dict_clear.__eq__
        # clears dict2 when the flag is on.  No result is checked.
        be_bad = False
        set1 = {bad_dict_clear()}
        dict2 = {bad_dict_clear(): None}
        be_bad = True
        set1.symmetric_difference_update(dict2)

    def test_iter_and_mutate(self):
        # Issue #24581
        s = set(range(100))
        s.clear()
        s.update(range(100))
        # Take an iterator, then clear and refill the set behind its back.
        si = iter(s)
        s.clear()
        # NOTE(review): 'a' is never read.  Presumably the allocation is
        # there to reuse memory released by clear(); confirm against the
        # issue before removing it.
        a = list(range(100))
        s.update(range(100))
        # Exhaust the stale iterator; the resulting list is discarded.
        list(si)

    def test_merge_and_mutate(self):
        # Local helper; shadows the module-level X inside this method only.
        class X:
            def __hash__(self):
                return hash(0)
            def __eq__(self, o):
                # 'other' is the enclosing method's variable, looked up at
                # call time.
                other.clear()
                return False

        # Not redundant: all X instances share one hash, so building the
        # comprehension below already calls X.__eq__, and 'other' must be
        # bound by then.
        other = set()
        other = {X() for i in range(10)}
        # The X instances hash like 0, so they collide with the 0 in s;
        # __eq__ then empties 'other' while it is the argument of update().
        s = {0}
        s.update(other)
1800
# Application tests (based on David Eppstein's graph recipes) ===================================
1802
def powerset(U):
    """Generates all subsets of a set or sequence U.

    Each subset is yielded as a frozenset.  The first element of U is peeled
    off, and every subset of the remainder is yielded twice: once as is and
    once with that element added.
    """
    remaining = iter(U)
    exhausted = object()
    first = next(remaining, exhausted)
    if first is exhausted:
        # Nothing left to pick from: the only subset is the empty one.
        yield frozenset()
        return
    head = frozenset([first])
    # The recursion consumes the same iterator, i.e. the rest of U.
    for subset in powerset(remaining):
        yield subset
        yield subset | head
1813
def cube(n):
    """Graph of n-dimensional hypercube.

    Vertices are the frozenset subsets of range(n).  Each vertex maps to the
    frozenset of vertices obtained by toggling one coordinate (symmetric
    difference with a one-element set).
    """
    axes = [frozenset([axis]) for axis in range(n)]
    return {
        vertex: frozenset([vertex ^ axis for axis in axes])
        for vertex in powerset(range(n))
    }
1819
def linegraph(G):
    """Graph, the vertices of which are edges of G,
    with two vertices being adjacent iff the corresponding
    edges share a vertex.

    G maps each vertex to an iterable of its neighbours.  The result maps
    frozenset([x, y]) for every edge to the frozenset of edges that touch
    x or y, excluding the edge itself.
    """
    adjacency = {}
    for u in G:
        for v in G[u]:
            # Other edges leaving each endpoint of the edge (u, v).
            around_u = {frozenset((u, w)) for w in G[u] if w != v}
            around_v = {frozenset((v, w)) for w in G[v] if w != u}
            adjacency[frozenset((u, v))] = frozenset(around_u | around_v)
    return adjacency
1831
def faces(G):
    """Return a set of faces in G.  Where a face is a set of vertices on that face

    G maps each vertex to its neighbours.  Every face is reported as a
    frozenset of its vertices.
    """
    # currently limited to triangles, squares, and pentagons
    found = set()
    for v1, neighbours in G.items():
        for v2 in neighbours:
            for v3 in G[v2]:
                if v1 == v3:
                    continue
                if v1 in G[v3]:
                    # The walk closes after three vertices: a triangle.
                    found.add(frozenset([v1, v2, v3]))
                    continue
                for v4 in G[v3]:
                    if v4 == v2:
                        continue
                    if v1 in G[v4]:
                        # The walk closes after four vertices: a square.
                        found.add(frozenset([v1, v2, v3, v4]))
                        continue
                    for v5 in G[v4]:
                        if v5 == v3 or v5 == v2:
                            continue
                        if v1 in G[v5]:
                            # The walk closes after five vertices: a pentagon.
                            found.add(frozenset([v1, v2, v3, v4, v5]))
    return found
1856
1857
class TestGraphs(unittest.TestCase):
    """Exercise set and frozenset through small polyhedron graphs."""

    def test_cube(self):

        graph = cube(3)                         # vertex --> its three neighbours
        corners = set(graph)
        self.assertEqual(len(corners), 8)       # a cube has eight vertices
        for neighbours in graph.values():
            self.assertEqual(len(neighbours), 3)    # three edges meet at each vertex
        reachable = {v for neighbours in graph.values() for v in neighbours}
        self.assertEqual(corners, reachable)    # neighbours are all known vertices

        cubefaces = faces(graph)
        self.assertEqual(len(cubefaces), 6)     # six faces
        for face in cubefaces:
            self.assertEqual(len(face), 4)      # every face is a square

    def test_cuboctahedron(self):

        # http://en.wikipedia.org/wiki/Cuboctahedron
        # 8 triangular faces and 6 square faces
        # 12 identical vertices each connecting a triangle and square

        g = cube(3)
        cuboctahedron = linegraph(g)            # vertex --> its four neighbours
        self.assertEqual(len(cuboctahedron), 12)    # twelve vertices

        vertices = set(cuboctahedron)
        for neighbours in cuboctahedron.values():
            self.assertEqual(len(neighbours), 4)    # four edges meet at each vertex
        othervertices = {v for neighbours in cuboctahedron.values()
                         for v in neighbours}
        self.assertEqual(vertices, othervertices)   # neighbours are all known vertices

        # Tally the faces by size; a size that never occurs counts as 0.
        facesizes = collections.Counter(len(face) for face in faces(cuboctahedron))
        self.assertEqual(facesizes[3], 8)       # eight triangular faces
        self.assertEqual(facesizes[4], 6)       # six square faces

        # Each cuboctahedron vertex is an edge of the cube, i.e. a pair of
        # cube vertices.
        for edge in cuboctahedron:
            self.assertEqual(len(edge), 2)
            for cubevert in edge:
                self.assertIn(cubevert, g)
1903
1904
1905#==============================================================================
1906
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
1909