• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import unittest
2from test import support
3from test.support import warnings_helper
4import gc
5import weakref
6import operator
7import copy
8import pickle
9from random import randrange, shuffle
10import warnings
11import collections
12import collections.abc
13import itertools
14
class PassThru(Exception):
    """Marker exception raised from inside an iterable argument.

    The tests use it to check that an error raised while a set operation
    consumes its argument reaches the caller unchanged.
    """
17
def check_pass_thru():
    # Generator whose first next() raises PassThru.  The yield below is
    # never reached; it only exists to make this a generator function.
    raise PassThru()
    yield 1
21
class BadCmp:
    """Object whose equality comparison always fails with RuntimeError.

    Every instance hashes to the same value (1), so two instances placed in
    the same hash table collide and force an __eq__ call.
    """

    def __eq__(self, other):
        raise RuntimeError()

    def __hash__(self):
        return 1
27
class ReprWrapper:
    """Helper for self-referential repr() tests.

    repr() of an instance is the repr() of whatever has been assigned to
    its ``value`` attribute; the attribute is not set by the class itself.
    """

    def __repr__(self):
        wrapped = self.value
        return repr(wrapped)
32
class HashCountingInt(int):
    """int subclass that records how many times it has been hashed.

    ``hash_count`` starts at 0 and is incremented on every __hash__ call;
    the hash value itself is the ordinary int hash.
    """

    def __init__(self, *args):
        self.hash_count = 0

    def __hash__(self):
        self.hash_count = self.hash_count + 1
        return super().__hash__()
40
41class TestJointOps:
42    # Tests common to both set and frozenset
43
44    def setUp(self):
45        self.word = word = 'simsalabim'
46        self.otherword = 'madagascar'
47        self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
48        self.s = self.thetype(word)
49        self.d = dict.fromkeys(word)
50
51    def test_new_or_init(self):
52        self.assertRaises(TypeError, self.thetype, [], 2)
53        self.assertRaises(TypeError, set().__init__, a=1)
54
55    def test_uniquification(self):
56        actual = sorted(self.s)
57        expected = sorted(self.d)
58        self.assertEqual(actual, expected)
59        self.assertRaises(PassThru, self.thetype, check_pass_thru())
60        self.assertRaises(TypeError, self.thetype, [[]])
61
62    def test_len(self):
63        self.assertEqual(len(self.s), len(self.d))
64
65    def test_contains(self):
66        for c in self.letters:
67            self.assertEqual(c in self.s, c in self.d)
68        self.assertRaises(TypeError, self.s.__contains__, [[]])
69        s = self.thetype([frozenset(self.letters)])
70        self.assertIn(self.thetype(self.letters), s)
71
72    def test_union(self):
73        u = self.s.union(self.otherword)
74        for c in self.letters:
75            self.assertEqual(c in u, c in self.d or c in self.otherword)
76        self.assertEqual(self.s, self.thetype(self.word))
77        self.assertEqual(type(u), self.basetype)
78        self.assertRaises(PassThru, self.s.union, check_pass_thru())
79        self.assertRaises(TypeError, self.s.union, [[]])
80        for C in set, frozenset, dict.fromkeys, str, list, tuple:
81            self.assertEqual(self.thetype('abcba').union(C('cdc')), set('abcd'))
82            self.assertEqual(self.thetype('abcba').union(C('efgfe')), set('abcefg'))
83            self.assertEqual(self.thetype('abcba').union(C('ccb')), set('abc'))
84            self.assertEqual(self.thetype('abcba').union(C('ef')), set('abcef'))
85            self.assertEqual(self.thetype('abcba').union(C('ef'), C('fg')), set('abcefg'))
86
87        # Issue #6573
88        x = self.thetype()
89        self.assertEqual(x.union(set([1]), x, set([2])), self.thetype([1, 2]))
90
91    def test_or(self):
92        i = self.s.union(self.otherword)
93        self.assertEqual(self.s | set(self.otherword), i)
94        self.assertEqual(self.s | frozenset(self.otherword), i)
95        try:
96            self.s | self.otherword
97        except TypeError:
98            pass
99        else:
100            self.fail("s|t did not screen-out general iterables")
101
102    def test_intersection(self):
103        i = self.s.intersection(self.otherword)
104        for c in self.letters:
105            self.assertEqual(c in i, c in self.d and c in self.otherword)
106        self.assertEqual(self.s, self.thetype(self.word))
107        self.assertEqual(type(i), self.basetype)
108        self.assertRaises(PassThru, self.s.intersection, check_pass_thru())
109        for C in set, frozenset, dict.fromkeys, str, list, tuple:
110            self.assertEqual(self.thetype('abcba').intersection(C('cdc')), set('cc'))
111            self.assertEqual(self.thetype('abcba').intersection(C('efgfe')), set(''))
112            self.assertEqual(self.thetype('abcba').intersection(C('ccb')), set('bc'))
113            self.assertEqual(self.thetype('abcba').intersection(C('ef')), set(''))
114            self.assertEqual(self.thetype('abcba').intersection(C('cbcf'), C('bag')), set('b'))
115        s = self.thetype('abcba')
116        z = s.intersection()
117        if self.thetype == frozenset():
118            self.assertEqual(id(s), id(z))
119        else:
120            self.assertNotEqual(id(s), id(z))
121
122    def test_isdisjoint(self):
123        def f(s1, s2):
124            'Pure python equivalent of isdisjoint()'
125            return not set(s1).intersection(s2)
126        for larg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
127            s1 = self.thetype(larg)
128            for rarg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
129                for C in set, frozenset, dict.fromkeys, str, list, tuple:
130                    s2 = C(rarg)
131                    actual = s1.isdisjoint(s2)
132                    expected = f(s1, s2)
133                    self.assertEqual(actual, expected)
134                    self.assertTrue(actual is True or actual is False)
135
136    def test_and(self):
137        i = self.s.intersection(self.otherword)
138        self.assertEqual(self.s & set(self.otherword), i)
139        self.assertEqual(self.s & frozenset(self.otherword), i)
140        try:
141            self.s & self.otherword
142        except TypeError:
143            pass
144        else:
145            self.fail("s&t did not screen-out general iterables")
146
147    def test_difference(self):
148        i = self.s.difference(self.otherword)
149        for c in self.letters:
150            self.assertEqual(c in i, c in self.d and c not in self.otherword)
151        self.assertEqual(self.s, self.thetype(self.word))
152        self.assertEqual(type(i), self.basetype)
153        self.assertRaises(PassThru, self.s.difference, check_pass_thru())
154        self.assertRaises(TypeError, self.s.difference, [[]])
155        for C in set, frozenset, dict.fromkeys, str, list, tuple:
156            self.assertEqual(self.thetype('abcba').difference(C('cdc')), set('ab'))
157            self.assertEqual(self.thetype('abcba').difference(C('efgfe')), set('abc'))
158            self.assertEqual(self.thetype('abcba').difference(C('ccb')), set('a'))
159            self.assertEqual(self.thetype('abcba').difference(C('ef')), set('abc'))
160            self.assertEqual(self.thetype('abcba').difference(), set('abc'))
161            self.assertEqual(self.thetype('abcba').difference(C('a'), C('b')), set('c'))
162
163    def test_sub(self):
164        i = self.s.difference(self.otherword)
165        self.assertEqual(self.s - set(self.otherword), i)
166        self.assertEqual(self.s - frozenset(self.otherword), i)
167        try:
168            self.s - self.otherword
169        except TypeError:
170            pass
171        else:
172            self.fail("s-t did not screen-out general iterables")
173
174    def test_symmetric_difference(self):
175        i = self.s.symmetric_difference(self.otherword)
176        for c in self.letters:
177            self.assertEqual(c in i, (c in self.d) ^ (c in self.otherword))
178        self.assertEqual(self.s, self.thetype(self.word))
179        self.assertEqual(type(i), self.basetype)
180        self.assertRaises(PassThru, self.s.symmetric_difference, check_pass_thru())
181        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
182        for C in set, frozenset, dict.fromkeys, str, list, tuple:
183            self.assertEqual(self.thetype('abcba').symmetric_difference(C('cdc')), set('abd'))
184            self.assertEqual(self.thetype('abcba').symmetric_difference(C('efgfe')), set('abcefg'))
185            self.assertEqual(self.thetype('abcba').symmetric_difference(C('ccb')), set('a'))
186            self.assertEqual(self.thetype('abcba').symmetric_difference(C('ef')), set('abcef'))
187
188    def test_xor(self):
189        i = self.s.symmetric_difference(self.otherword)
190        self.assertEqual(self.s ^ set(self.otherword), i)
191        self.assertEqual(self.s ^ frozenset(self.otherword), i)
192        try:
193            self.s ^ self.otherword
194        except TypeError:
195            pass
196        else:
197            self.fail("s^t did not screen-out general iterables")
198
199    def test_equality(self):
200        self.assertEqual(self.s, set(self.word))
201        self.assertEqual(self.s, frozenset(self.word))
202        self.assertEqual(self.s == self.word, False)
203        self.assertNotEqual(self.s, set(self.otherword))
204        self.assertNotEqual(self.s, frozenset(self.otherword))
205        self.assertEqual(self.s != self.word, True)
206
207    def test_setOfFrozensets(self):
208        t = map(frozenset, ['abcdef', 'bcd', 'bdcb', 'fed', 'fedccba'])
209        s = self.thetype(t)
210        self.assertEqual(len(s), 3)
211
212    def test_sub_and_super(self):
213        p, q, r = map(self.thetype, ['ab', 'abcde', 'def'])
214        self.assertTrue(p < q)
215        self.assertTrue(p <= q)
216        self.assertTrue(q <= q)
217        self.assertTrue(q > p)
218        self.assertTrue(q >= p)
219        self.assertFalse(q < r)
220        self.assertFalse(q <= r)
221        self.assertFalse(q > r)
222        self.assertFalse(q >= r)
223        self.assertTrue(set('a').issubset('abc'))
224        self.assertTrue(set('abc').issuperset('a'))
225        self.assertFalse(set('a').issubset('cbs'))
226        self.assertFalse(set('cbs').issuperset('a'))
227
228    def test_pickling(self):
229        for i in range(pickle.HIGHEST_PROTOCOL + 1):
230            p = pickle.dumps(self.s, i)
231            dup = pickle.loads(p)
232            self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup))
233            if type(self.s) not in (set, frozenset):
234                self.s.x = 10
235                p = pickle.dumps(self.s, i)
236                dup = pickle.loads(p)
237                self.assertEqual(self.s.x, dup.x)
238
239    def test_iterator_pickling(self):
240        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
241            itorg = iter(self.s)
242            data = self.thetype(self.s)
243            d = pickle.dumps(itorg, proto)
244            it = pickle.loads(d)
245            # Set iterators unpickle as list iterators due to the
246            # undefined order of set items.
247            # self.assertEqual(type(itorg), type(it))
248            self.assertIsInstance(it, collections.abc.Iterator)
249            self.assertEqual(self.thetype(it), data)
250
251            it = pickle.loads(d)
252            try:
253                drop = next(it)
254            except StopIteration:
255                continue
256            d = pickle.dumps(it, proto)
257            it = pickle.loads(d)
258            self.assertEqual(self.thetype(it), data - self.thetype((drop,)))
259
260    def test_deepcopy(self):
261        class Tracer:
262            def __init__(self, value):
263                self.value = value
264            def __hash__(self):
265                return self.value
266            def __deepcopy__(self, memo=None):
267                return Tracer(self.value + 1)
268        t = Tracer(10)
269        s = self.thetype([t])
270        dup = copy.deepcopy(s)
271        self.assertNotEqual(id(s), id(dup))
272        for elem in dup:
273            newt = elem
274        self.assertNotEqual(id(t), id(newt))
275        self.assertEqual(t.value + 1, newt.value)
276
277    def test_gc(self):
278        # Create a nest of cycles to exercise overall ref count check
279        class A:
280            pass
281        s = set(A() for i in range(1000))
282        for elem in s:
283            elem.cycle = s
284            elem.sub = elem
285            elem.set = set([elem])
286
287    def test_subclass_with_custom_hash(self):
288        # Bug #1257731
289        class H(self.thetype):
290            def __hash__(self):
291                return int(id(self) & 0x7fffffff)
292        s=H()
293        f=set()
294        f.add(s)
295        self.assertIn(s, f)
296        f.remove(s)
297        f.add(s)
298        f.discard(s)
299
300    def test_badcmp(self):
301        s = self.thetype([BadCmp()])
302        # Detect comparison errors during insertion and lookup
303        self.assertRaises(RuntimeError, self.thetype, [BadCmp(), BadCmp()])
304        self.assertRaises(RuntimeError, s.__contains__, BadCmp())
305        # Detect errors during mutating operations
306        if hasattr(s, 'add'):
307            self.assertRaises(RuntimeError, s.add, BadCmp())
308            self.assertRaises(RuntimeError, s.discard, BadCmp())
309            self.assertRaises(RuntimeError, s.remove, BadCmp())
310
311    def test_cyclical_repr(self):
312        w = ReprWrapper()
313        s = self.thetype([w])
314        w.value = s
315        if self.thetype == set:
316            self.assertEqual(repr(s), '{set(...)}')
317        else:
318            name = repr(s).partition('(')[0]    # strip class name
319            self.assertEqual(repr(s), '%s({%s(...)})' % (name, name))
320
321    def test_do_not_rehash_dict_keys(self):
322        n = 10
323        d = dict.fromkeys(map(HashCountingInt, range(n)))
324        self.assertEqual(sum(elem.hash_count for elem in d), n)
325        s = self.thetype(d)
326        self.assertEqual(sum(elem.hash_count for elem in d), n)
327        s.difference(d)
328        self.assertEqual(sum(elem.hash_count for elem in d), n)
329        if hasattr(s, 'symmetric_difference_update'):
330            s.symmetric_difference_update(d)
331        self.assertEqual(sum(elem.hash_count for elem in d), n)
332        d2 = dict.fromkeys(set(d))
333        self.assertEqual(sum(elem.hash_count for elem in d), n)
334        d3 = dict.fromkeys(frozenset(d))
335        self.assertEqual(sum(elem.hash_count for elem in d), n)
336        d3 = dict.fromkeys(frozenset(d), 123)
337        self.assertEqual(sum(elem.hash_count for elem in d), n)
338        self.assertEqual(d3, dict.fromkeys(d, 123))
339
340    def test_container_iterator(self):
341        # Bug #3680: tp_traverse was not implemented for set iterator object
342        class C(object):
343            pass
344        obj = C()
345        ref = weakref.ref(obj)
346        container = set([obj, 1])
347        obj.x = iter(container)
348        del obj, container
349        gc.collect()
350        self.assertTrue(ref() is None, "Cycle was not collected")
351
352    def test_free_after_iterating(self):
353        support.check_free_after_iterating(self, iter, self.thetype)
354
class TestSet(TestJointOps, unittest.TestCase):
    # Runs the shared TestJointOps checks against the mutable built-in set
    # and adds tests for the mutating API.
    thetype = set
    basetype = set

    def test_init(self):
        # Calling __init__ again replaces the contents.
        s = self.thetype()
        s.__init__(self.word)
        self.assertEqual(s, set(self.word))
        s.__init__(self.otherword)
        self.assertEqual(s, set(self.otherword))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        # Constructing from an existing instance must copy it.
        s = self.thetype(range(3))
        t = self.thetype(s)
        self.assertNotEqual(id(s), id(t))

    def test_set_literal(self):
        s = set([1,2,3])
        t = {1,2,3}
        self.assertEqual(s, t)

    def test_set_literal_insertion_order(self):
        # SF Issue #26020 -- Expect left to right insertion
        s = {1, 1.0, True}
        self.assertEqual(len(s), 1)
        stored_value = s.pop()
        self.assertEqual(type(stored_value), int)

    def test_set_literal_evaluation_order(self):
        # Expect left to right expression evaluation
        events = []
        def record(obj):
            events.append(obj)
        s = {record(1), record(2), record(3)}
        self.assertEqual(events, [1, 2, 3])

    def test_hash(self):
        # Mutable sets are unhashable.
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, set())
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertNotEqual(id(self.s), id(dup))
        self.assertEqual(type(dup), self.basetype)

    def test_add(self):
        self.s.add('Q')
        self.assertIn('Q', self.s)
        # Adding an element that is already present changes nothing.
        dup = self.s.copy()
        self.s.add('Q')
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])

    def test_remove(self):
        self.s.remove('a')
        self.assertNotIn('a', self.s)
        self.assertRaises(KeyError, self.s.remove, 'Q')
        self.assertRaises(TypeError, self.s.remove, [])
        # A set-like key built with thetype removes the equal frozenset
        # element.
        s = self.thetype([frozenset(self.word)])
        self.assertIn(self.thetype(self.word), s)
        s.remove(self.thetype(self.word))
        self.assertNotIn(self.thetype(self.word), s)
        self.assertRaises(KeyError, self.s.remove, self.thetype(self.word))

    def test_remove_keyerror_unpacking(self):
        # bug:  www.python.org/sf/1576657
        # The missing key must be the KeyError's sole argument, even when
        # the key is itself a tuple.
        for v1 in ['Q', (1,)]:
            with self.assertRaises(KeyError) as cm:
                self.s.remove(v1)
            v2 = cm.exception.args[0]
            self.assertEqual(v1, v2)

    def test_remove_keyerror_set(self):
        # The KeyError must carry the very object that was passed in.
        key = self.thetype([3, 4])
        with self.assertRaises(KeyError) as cm:
            self.s.remove(key)
        reported = cm.exception.args[0]
        self.assertIs(reported, key,
                      "KeyError should be {0}, not {1}".format(key, reported))

    def test_discard(self):
        self.s.discard('a')
        self.assertNotIn('a', self.s)
        # Discarding a missing element is silently ignored.
        self.s.discard('Q')
        self.assertRaises(TypeError, self.s.discard, [])
        s = self.thetype([frozenset(self.word)])
        self.assertIn(self.thetype(self.word), s)
        s.discard(self.thetype(self.word))
        self.assertNotIn(self.thetype(self.word), s)
        s.discard(self.thetype(self.word))

    def test_pop(self):
        for _ in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        # Popping from the now-empty set raises.
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            self.assertIn(c, self.s)
        self.assertRaises(PassThru, self.s.update, check_pass_thru())
        self.assertRaises(TypeError, self.s.update, [[]])
        for p, q in (('cdc', 'abcd'), ('efgfe', 'abcefg'), ('ccb', 'abc'), ('ef', 'abcef')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.update(C(p)), None)
                self.assertEqual(s, set(q))
        for p in ('cdc', 'efgfe', 'ccb', 'ef', 'abcda'):
            q = 'ahi'
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.update(C(p), C(q)), None)
                # Build the expected value from the initial contents, not
                # from the result itself, so that spurious extra elements
                # would be detected too.
                self.assertEqual(s, set('abcba') | set(p) | set(q))

    def test_ior(self):
        self.s |= set(self.otherword)
        for c in (self.word + self.otherword):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if c in self.otherword and c in self.word:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.intersection_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.intersection_update, [[]])
        for p, q in (('cdc', 'c'), ('efgfe', ''), ('ccb', 'bc'), ('ef', '')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.intersection_update(C(p)), None)
                self.assertEqual(s, set(q))
                ss = 'abcba'
                s = self.thetype(ss)
                t = 'cbc'
                self.assertEqual(s.intersection_update(C(p), C(t)), None)
                self.assertEqual(s, set('abcba')&set(p)&set(t))

    def test_iand(self):
        self.s &= set(self.otherword)
        for c in (self.word + self.otherword):
            if c in self.otherword and c in self.word:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if c in self.word and c not in self.otherword:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.difference_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
        for p, q in (('cdc', 'ab'), ('efgfe', 'abc'), ('ccb', 'a'), ('ef', 'abc')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.difference_update(C(p)), None)
                self.assertEqual(s, set(q))

                # No arguments: nothing is removed.
                s = self.thetype('abcdefghih')
                s.difference_update()
                self.assertEqual(s, self.thetype('abcdefghih'))

                s = self.thetype('abcdefghih')
                s.difference_update(C('aba'))
                self.assertEqual(s, self.thetype('cdefghih'))

                # Several arguments: all of them are removed.
                s = self.thetype('abcdefghih')
                s.difference_update(C('cdc'), C('aba'))
                self.assertEqual(s, self.thetype('efghih'))

    def test_isub(self):
        self.s -= set(self.otherword)
        for c in (self.word + self.otherword):
            if c in self.word and c not in self.otherword:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if (c in self.word) ^ (c in self.otherword):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.symmetric_difference_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
        for p, q in (('cdc', 'abd'), ('efgfe', 'abcefg'), ('ccb', 'a'), ('ef', 'abcef')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.symmetric_difference_update(C(p)), None)
                self.assertEqual(s, set(q))

    def test_ixor(self):
        self.s ^= set(self.otherword)
        for c in (self.word + self.otherword):
            if (c in self.word) ^ (c in self.otherword):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        # In-place operators must behave when both operands are the
        # same object.
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, self.thetype())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, self.thetype())

    def test_weakref(self):
        s = self.thetype('gallahad')
        p = weakref.proxy(s)
        self.assertEqual(str(p), str(s))
        s = None
        support.gc_collect()  # For PyPy or other GCs.
        # The referent is gone, so using the proxy must fail.
        self.assertRaises(ReferenceError, str, p)

    def test_rich_compare(self):
        class TestRichSetCompare:
            def __gt__(self, some_set):
                self.gt_called = True
                return False
            def __lt__(self, some_set):
                self.lt_called = True
                return False
            def __ge__(self, some_set):
                self.ge_called = True
                return False
            def __le__(self, some_set):
                self.le_called = True
                return False

        # This first tries the builtin rich set comparison, which doesn't know
        # how to handle the custom object. Upon returning NotImplemented, the
        # corresponding comparison on the right object is invoked.
        myset = {1, 2, 3}

        myobj = TestRichSetCompare()
        myset < myobj
        self.assertTrue(myobj.gt_called)

        myobj = TestRichSetCompare()
        myset > myobj
        self.assertTrue(myobj.lt_called)

        myobj = TestRichSetCompare()
        myset <= myobj
        self.assertTrue(myobj.ge_called)

        myobj = TestRichSetCompare()
        myset >= myobj
        self.assertTrue(myobj.le_called)

    @unittest.skipUnless(hasattr(set, "test_c_api"),
                         'C API test only available in a debug build')
    def test_c_api(self):
        self.assertEqual(set().test_c_api(), True)
639
class SetSubclass(set):
    """Trivial subclass of set that adds nothing of its own."""
642
class TestSetSubclass(TestSet):
    # Re-runs the TestSet suite with fixtures built from SetSubclass;
    # results of non-mutating operations are still expected to be plain set.
    basetype = set
    thetype = SetSubclass
646
class SetSubclassWithKeywordArgs(set):
    """set subclass whose __init__ accepts an extra, ignored ``newarg``."""

    def __init__(self, iterable=[], newarg=None):
        # Only the iterable is forwarded; newarg is accepted and dropped.
        super().__init__(iterable)
650
class TestSetSubclassWithKeywordArgs(TestSet):
    # Inherits the whole TestSet suite unchanged (thetype is not overridden
    # here) and adds one construction check for the keyword-taking subclass.

    def test_keywords_in_subclass(self):
        'SF bug #1486663 -- this used to erroneously raise a TypeError'
        # Success is simply "no exception"; the instance itself is unused.
        _ = SetSubclassWithKeywordArgs(newarg=1)
656
class TestFrozenSet(TestJointOps, unittest.TestCase):
    # Runs the shared TestJointOps checks against frozenset and adds tests
    # for hashing and identity-preserving behaviour.
    basetype = frozenset
    thetype = frozenset

    def test_init(self):
        # Calling __init__ on an existing instance must not alter it.
        frozen = self.thetype(self.word)
        frozen.__init__(self.otherword)
        self.assertEqual(frozen, set(self.word))

    def test_constructor_identity(self):
        # Constructing from an existing instance hands back that instance.
        original = self.thetype(range(3))
        rebuilt = self.thetype(original)
        self.assertEqual(id(original), id(rebuilt))

    def test_hash(self):
        first = hash(self.thetype('abcdeb'))
        second = hash(self.thetype('ebecda'))
        self.assertEqual(first, second)

        # make sure that all permutations give the same hash value
        n = 100
        seq = [randrange(n) for _ in range(n)]
        seen_hashes = set()
        for _ in range(200):
            shuffle(seq)
            seen_hashes.add(hash(self.thetype(seq)))
        self.assertEqual(len(seen_hashes), 1)

    def test_copy(self):
        # copy() returns the object itself.
        self.assertEqual(id(self.s), id(self.s.copy()))

    def test_frozen_as_dictkey(self):
        # Two equal but distinct instances must address the same dict slot.
        seq = list(range(10)) + list('abcdefg') + ['apple']
        key1 = self.thetype(seq)
        key2 = self.thetype(reversed(seq))
        self.assertEqual(key1, key2)
        self.assertNotEqual(id(key1), id(key2))
        table = {key1: 42}
        self.assertEqual(table[key2], 42)

    def test_hash_caching(self):
        # Hashing the same object twice gives the same value.
        f = self.thetype('abcdcda')
        self.assertEqual(hash(f), hash(f))

    def test_hash_effectiveness(self):
        # Every subset of {1, ..., 13} must get its own hash value.
        n = 13
        elemmasks = [(pos + 1, 1 << pos) for pos in range(n)]
        hashvalues = {
            hash(frozenset([elem for elem, mask in elemmasks if mask & bits]))
            for bits in range(2 ** n)
        }
        self.assertEqual(len(hashvalues), 2 ** n)

        def zf_range(n):
            # https://en.wikipedia.org/wiki/Set-theoretic_definition_of_natural_numbers
            nums = [frozenset()]
            for _ in range(n - 1):
                nums.append(frozenset(nums))
            return nums[:n]

        def powerset(s):
            # Yield every subset of s as a frozenset, smallest sizes first.
            for size in range(len(s) + 1):
                for combo in itertools.combinations(s, size):
                    yield frozenset(combo)

        # For each n, the low n bits of the hashes of all 2**n subsets must
        # take more than a quarter of the 2**n possible values.
        for n in range(18):
            t = 2 ** n
            mask = t - 1
            for nums in (range, zf_range):
                low_bits = {hash(subset) & mask for subset in powerset(nums(n))}
                u = len(low_bits)
                self.assertGreater(4 * u, t)
729
class FrozenSetSubclass(frozenset):
    """Trivial subclass of frozenset that adds nothing of its own."""
732
class TestFrozenSetSubclass(TestFrozenSet):
    # Re-runs the TestFrozenSet suite with fixtures built from
    # FrozenSetSubclass.  Unlike the base suite, these tests expect a new
    # object from the constructor and from copy().
    basetype = frozenset
    thetype = FrozenSetSubclass

    def test_constructor_identity(self):
        original = self.thetype(range(3))
        rebuilt = self.thetype(original)
        self.assertNotEqual(id(original), id(rebuilt))

    def test_copy(self):
        self.assertNotEqual(id(self.s), id(self.s.copy()))

    def test_nested_empty_constructor(self):
        empty = self.thetype()
        rebuilt = self.thetype(empty)
        self.assertEqual(empty, rebuilt)

    def test_singleton_empty_frozenset(self):
        Frozenset = self.thetype
        f = frozenset()
        F = Frozenset()
        # All objects stay referenced by the list so their ids are
        # comparable at the time of the check.
        efs = [Frozenset(), Frozenset([]), Frozenset(()), Frozenset(''),
               Frozenset(), Frozenset([]), Frozenset(()), Frozenset(''),
               Frozenset(range(0)), Frozenset(Frozenset()),
               Frozenset(frozenset()), f, F, Frozenset(f), Frozenset(F)]
        # All empty frozenset subclass instances should have different ids
        distinct_ids = {id(obj) for obj in efs}
        self.assertEqual(len(distinct_ids), len(efs))
761
762# Tests taken from test_sets.py =============================================
763
# Shared empty set used as an operand by the TestBasicOps checks below.
# NOTE(review): the visible tests only read it; presumably nothing mutates
# it -- confirm before adding tests that do.
empty_set = set()
765
766#==============================================================================
767
class TestBasicOps:
    """Mixin with basic checks that are run against several fixture sets.

    Concrete subclasses assign ``set``, ``dup``, ``values`` and ``length``
    in ``setUp``; ``repr`` is needed unless ``test_repr`` is overridden.
    """

    def test_repr(self):
        expected = self.repr
        if expected is None:
            # No deterministic repr for this fixture; nothing to compare.
            return
        self.assertEqual(repr(self.set), expected)

    def check_repr_against_values(self):
        # Order-insensitive repr check: strip the braces, split on ', ' and
        # compare the sorted pieces with the sorted reprs of the values.
        rendered = repr(self.set)
        self.assertTrue(rendered.startswith('{'))
        self.assertTrue(rendered.endswith('}'))

        found = sorted(rendered[1:-1].split(', '))
        wanted = sorted(map(repr, self.values))
        self.assertEqual(found, wanted)

    def test_length(self):
        self.assertEqual(len(self.set), self.length)

    def test_self_equality(self):
        self.assertEqual(self.set, self.set)

    def test_equivalent_equality(self):
        self.assertEqual(self.set, self.dup)

    def test_copy(self):
        self.assertEqual(self.set.copy(), self.dup)

    def test_self_union(self):
        self.assertEqual(self.set | self.set, self.dup)

    def test_empty_union(self):
        self.assertEqual(self.set | empty_set, self.dup)

    def test_union_empty(self):
        self.assertEqual(empty_set | self.set, self.dup)

    def test_self_intersection(self):
        self.assertEqual(self.set & self.set, self.dup)

    def test_empty_intersection(self):
        self.assertEqual(self.set & empty_set, empty_set)

    def test_intersection_empty(self):
        self.assertEqual(empty_set & self.set, empty_set)

    def test_self_isdisjoint(self):
        # A set is disjoint from itself only when it is empty.
        disjoint = self.set.isdisjoint(self.set)
        self.assertEqual(disjoint, not self.set)

    def test_empty_isdisjoint(self):
        self.assertEqual(self.set.isdisjoint(empty_set), True)

    def test_isdisjoint_empty(self):
        self.assertEqual(empty_set.isdisjoint(self.set), True)

    def test_self_symmetric_difference(self):
        self.assertEqual(self.set ^ self.set, empty_set)

    def test_empty_symmetric_difference(self):
        self.assertEqual(self.set ^ empty_set, self.set)

    def test_self_difference(self):
        self.assertEqual(self.set - self.set, empty_set)

    def test_empty_difference(self):
        self.assertEqual(self.set - empty_set, self.dup)

    def test_empty_difference_rev(self):
        self.assertEqual(empty_set - self.set, empty_set)

    def test_iteration(self):
        for member in self.set:
            self.assertIn(member, self.values)
        # A fresh iterator should report the full size as its length hint.
        fresh_iter = iter(self.set)
        self.assertEqual(fresh_iter.__length_hint__(), len(self.set))

    def test_pickling(self):
        # Round-trip through every pickle protocol.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            restored = pickle.loads(pickle.dumps(self.set, proto))
            self.assertEqual(self.set, restored,
                             "%s != %s" % (self.set, restored))

    def test_issue_37219(self):
        # A non-iterable argument must raise TypeError.
        self.assertRaises(TypeError, set().difference, 123)
        self.assertRaises(TypeError, set().difference_update, 123)
871
872#------------------------------------------------------------------------------
873
class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase):
    """Run the TestBasicOps checks against an empty set."""

    def setUp(self):
        members = []
        self.case = "empty set"
        self.values = members
        self.set, self.dup = set(members), set(members)
        self.length = 0
        self.repr = "set()"
882
883#------------------------------------------------------------------------------
884
class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
    """Run the TestBasicOps checks against a set holding one number."""

    def setUp(self):
        members = [3]
        self.case = "unit set (number)"
        self.values = members
        self.set, self.dup = set(members), set(members)
        self.length = 1
        self.repr = "{3}"

    def test_in(self):
        present = 3
        self.assertIn(present, self.set)

    def test_not_in(self):
        absent = 2
        self.assertNotIn(absent, self.set)
899
900#------------------------------------------------------------------------------
901
class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
    """Run the TestBasicOps checks against a set holding one tuple."""

    def setUp(self):
        members = [(0, "zero")]
        self.case = "unit set (tuple)"
        self.values = members
        self.set, self.dup = set(members), set(members)
        self.length = 1
        self.repr = "{(0, 'zero')}"

    def test_in(self):
        present = (0, "zero")
        self.assertIn(present, self.set)

    def test_not_in(self):
        absent = 9
        self.assertNotIn(absent, self.set)
916
917#------------------------------------------------------------------------------
918
class TestBasicOpsTriple(TestBasicOps, unittest.TestCase):
    """Run the TestBasicOps checks against three members of mixed type."""

    def setUp(self):
        members = [0, "zero", operator.add]
        self.case = "triple set"
        self.values = members
        self.set, self.dup = set(members), set(members)
        self.length = 3
        # None makes the inherited test_repr skip its comparison.
        self.repr = None
927
928#------------------------------------------------------------------------------
929
class TestBasicOpsString(TestBasicOps, unittest.TestCase):
    """Run the TestBasicOps checks against a set of three strings."""

    def setUp(self):
        members = ["a", "b", "c"]
        self.case = "string set"
        self.values = members
        self.set, self.dup = set(members), set(members)
        self.length = 3

    def test_repr(self):
        # Compare the repr order-insensitively instead of against a fixed
        # string.
        self.check_repr_against_values()
940
941#------------------------------------------------------------------------------
942
class TestBasicOpsBytes(TestBasicOps, unittest.TestCase):
    """Run the TestBasicOps checks against a set of three bytes objects."""

    def setUp(self):
        members = [b"a", b"b", b"c"]
        self.case = "bytes set"
        self.values = members
        self.set, self.dup = set(members), set(members)
        self.length = 3

    def test_repr(self):
        # Compare the repr order-insensitively instead of against a fixed
        # string.
        self.check_repr_against_values()
953
954#------------------------------------------------------------------------------
955
class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
    """Run the TestBasicOps checks against a set mixing str and bytes."""

    def setUp(self):
        # Save the warning state, then ignore BytesWarning for this test.
        self._warning_filters = warnings_helper.check_warnings()
        self._warning_filters.__enter__()
        # Register the exit right away.  Unlike tearDown(), cleanups still
        # run when a later statement in setUp() raises, so the modified
        # warning state cannot leak into other tests.
        self.addCleanup(self._warning_filters.__exit__, None, None, None)
        warnings.simplefilter('ignore', BytesWarning)
        self.case   = "string and bytes set"
        self.values = ["a", "b", b"a", b"b"]
        self.set    = set(self.values)
        self.dup    = set(self.values)
        self.length = 4

    def test_repr(self):
        # Compare the repr order-insensitively instead of against a fixed
        # string.
        self.check_repr_against_values()
972
973#==============================================================================
974
def baditer():
    """Generator that raises TypeError as soon as it is advanced."""
    raise TypeError()
    # Never reached; the yield only makes this a generator function.
    yield True
978
def gooditer():
    """Generator that yields the single value True."""
    yield from (True,)
981
class TestExceptionPropagation(unittest.TestCase):
    """SF 628246:  Set constructor should not trap iterator TypeErrors"""

    def test_instanceWithException(self):
        broken = baditer()
        with self.assertRaises(TypeError):
            set(broken)

    def test_instancesWithoutException(self):
        # All of these iterables should load without exception.
        loadable = (
            [1, 2, 3],
            (1, 2, 3),
            {'one': 1, 'two': 2, 'three': 3},
            range(3),
            'abc',
            gooditer(),
        )
        for iterable in loadable:
            set(iterable)

    def test_changingSizeWhileIterating(self):
        s = {1, 2, 3}
        try:
            for _ in s:
                s.update([4])
        except RuntimeError:
            # Expected: the set grew while being iterated.
            return
        self.fail("no exception when changing size during iteration")
1006
1007#==============================================================================
1008
class TestSetOfSets(unittest.TestCase):
    """A frozenset can be stored in, and removed from, a set."""

    def test_constructor(self):
        inner = frozenset([1])
        outer = {inner}
        popped = outer.pop()
        self.assertEqual(type(popped), frozenset)
        # Rebuild set of sets with .add method
        outer.add(inner)
        outer.remove(inner)
        # Verify that remove worked
        self.assertEqual(outer, set())
        # Absence of KeyError indicates working fine
        outer.discard(inner)
1019
1020#==============================================================================
1021
class TestBinaryOps(unittest.TestCase):
    """Binary set operators and isdisjoint() applied to the set {2, 4, 6}."""

    def setUp(self):
        self.set = {2, 4, 6}

    def test_eq(self):              # SF bug 643115
        from_dict = set({2: 1, 4: 3, 6: 5})
        self.assertEqual(self.set, from_dict)

    def test_union_subset(self):
        self.assertEqual(self.set | {2}, {2, 4, 6})

    def test_union_superset(self):
        self.assertEqual(self.set | {2, 4, 6, 8}, {2, 4, 6, 8})

    def test_union_overlap(self):
        self.assertEqual(self.set | {3, 4, 5}, {2, 3, 4, 5, 6})

    def test_union_non_overlap(self):
        self.assertEqual(self.set | {8}, {2, 4, 6, 8})

    def test_intersection_subset(self):
        self.assertEqual(self.set & {2, 4}, {2, 4})

    def test_intersection_superset(self):
        self.assertEqual(self.set & {2, 4, 6, 8}, {2, 4, 6})

    def test_intersection_overlap(self):
        self.assertEqual(self.set & {3, 4, 5}, {4})

    def test_intersection_non_overlap(self):
        self.assertEqual(self.set & {8}, empty_set)

    def test_isdisjoint_subset(self):
        self.assertEqual(self.set.isdisjoint({2, 4}), False)

    def test_isdisjoint_superset(self):
        self.assertEqual(self.set.isdisjoint({2, 4, 6, 8}), False)

    def test_isdisjoint_overlap(self):
        self.assertEqual(self.set.isdisjoint({3, 4, 5}), False)

    def test_isdisjoint_non_overlap(self):
        self.assertEqual(self.set.isdisjoint({8}), True)

    def test_sym_difference_subset(self):
        self.assertEqual(self.set ^ {2, 4}, {6})

    def test_sym_difference_superset(self):
        self.assertEqual(self.set ^ {2, 4, 6, 8}, {8})

    def test_sym_difference_overlap(self):
        self.assertEqual(self.set ^ {3, 4, 5}, {2, 3, 5, 6})

    def test_sym_difference_non_overlap(self):
        self.assertEqual(self.set ^ {8}, {2, 4, 6, 8})
1092
1093#==============================================================================
1094
class TestUpdateOps(unittest.TestCase):
    """In-place operators and *_update methods applied to the set {2, 4, 6}."""

    def setUp(self):
        self.set = {2, 4, 6}

    def test_union_subset(self):
        self.set |= {2}
        self.assertEqual(self.set, {2, 4, 6})

    def test_union_superset(self):
        self.set |= {2, 4, 6, 8}
        self.assertEqual(self.set, {2, 4, 6, 8})

    def test_union_overlap(self):
        self.set |= {3, 4, 5}
        self.assertEqual(self.set, {2, 3, 4, 5, 6})

    def test_union_non_overlap(self):
        self.set |= {8}
        self.assertEqual(self.set, {2, 4, 6, 8})

    def test_union_method_call(self):
        self.set.update({3, 4, 5})
        self.assertEqual(self.set, {2, 3, 4, 5, 6})

    def test_intersection_subset(self):
        self.set &= {2, 4}
        self.assertEqual(self.set, {2, 4})

    def test_intersection_superset(self):
        self.set &= {2, 4, 6, 8}
        self.assertEqual(self.set, {2, 4, 6})

    def test_intersection_overlap(self):
        self.set &= {3, 4, 5}
        self.assertEqual(self.set, {4})

    def test_intersection_non_overlap(self):
        self.set &= {8}
        self.assertEqual(self.set, empty_set)

    def test_intersection_method_call(self):
        self.set.intersection_update({3, 4, 5})
        self.assertEqual(self.set, {4})

    def test_sym_difference_subset(self):
        self.set ^= {2, 4}
        self.assertEqual(self.set, {6})

    def test_sym_difference_superset(self):
        self.set ^= {2, 4, 6, 8}
        self.assertEqual(self.set, {8})

    def test_sym_difference_overlap(self):
        self.set ^= {3, 4, 5}
        self.assertEqual(self.set, {2, 3, 5, 6})

    def test_sym_difference_non_overlap(self):
        self.set ^= {8}
        self.assertEqual(self.set, {2, 4, 6, 8})

    def test_sym_difference_method_call(self):
        self.set.symmetric_difference_update({3, 4, 5})
        self.assertEqual(self.set, {2, 3, 5, 6})

    def test_difference_subset(self):
        self.set -= {2, 4}
        self.assertEqual(self.set, {6})

    def test_difference_superset(self):
        self.set -= {2, 4, 6, 8}
        self.assertEqual(self.set, set())

    def test_difference_overlap(self):
        self.set -= {3, 4, 5}
        self.assertEqual(self.set, {2, 6})

    def test_difference_non_overlap(self):
        self.set -= {8}
        self.assertEqual(self.set, {2, 4, 6})

    def test_difference_method_call(self):
        self.set.difference_update({3, 4, 5})
        self.assertEqual(self.set, {2, 6})
1178
1179#==============================================================================
1180
class TestMutate(unittest.TestCase):
    """Mutating methods of set, exercised on {"a", "b", "c"}."""

    def setUp(self):
        self.values = ["a", "b", "c"]
        self.set = set(self.values)

    def test_add_present(self):
        self.set.add("c")
        self.assertEqual(self.set, set("abc"))

    def test_add_absent(self):
        self.set.add("d")
        self.assertEqual(self.set, set("abcd"))

    def test_add_until_full(self):
        # Each add of a new value must grow the set by exactly one.
        growing = set()
        for count, value in enumerate(self.values, 1):
            growing.add(value)
            self.assertEqual(len(growing), count)
        self.assertEqual(growing, self.set)

    def test_remove_present(self):
        self.set.remove("b")
        self.assertEqual(self.set, set("ac"))

    def test_remove_absent(self):
        try:
            self.set.remove("d")
        except LookupError:
            return
        self.fail("Removing missing element should have raised LookupError")

    def test_remove_until_empty(self):
        # Each remove must shrink the set by exactly one.
        initial_len = len(self.set)
        for removed, value in enumerate(self.values, 1):
            self.set.remove(value)
            self.assertEqual(len(self.set), initial_len - removed)

    def test_discard_present(self):
        self.set.discard("c")
        self.assertEqual(self.set, set("ab"))

    def test_discard_absent(self):
        self.set.discard("d")
        self.assertEqual(self.set, set("abc"))

    def test_clear(self):
        self.set.clear()
        self.assertEqual(len(self.set), 0)

    def test_pop(self):
        # Drain the set and check every original value came out once.
        popped = {}
        while self.set:
            popped.setdefault(self.set.pop())
        self.assertEqual(len(popped), len(self.values))
        for value in self.values:
            self.assertIn(value, popped)

    def test_update_empty_tuple(self):
        self.set.update(())
        self.assertEqual(self.set, set(self.values))

    def test_update_unit_tuple_overlap(self):
        self.set.update(("a",))
        self.assertEqual(self.set, set(self.values))

    def test_update_unit_tuple_non_overlap(self):
        self.set.update(("a", "z"))
        expected = set(self.values + ["z"])
        self.assertEqual(self.set, expected)
1252
1253#==============================================================================
1254
class TestSubsets:
    """Mixin checking set comparisons and issubset()/issuperset().

    Subclasses provide ``left`` and ``right`` (the two operands) and
    ``cases``, the operator spellings expected to be true for
    ``left <op> right``; every other operator is expected to be false.
    """

    # Operator spelling -> equivalent set method, where one exists.
    case2method = {"<=": "issubset",
                   ">=": "issuperset",
                  }

    # Operator spelling -> spelling that gives the same answer with the
    # operands swapped.
    reverse = {"==": "==",
               "!=": "!=",
               "<":  ">",
               ">":  "<",
               "<=": ">=",
               ">=": "<=",
              }

    # Operator spelling -> function performing the comparison.  Used in
    # place of eval() on concatenated source strings.
    _case2op = {"==": operator.eq,
                "!=": operator.ne,
                "<":  operator.lt,
                ">":  operator.gt,
                "<=": operator.le,
                ">=": operator.ge,
               }

    def test_issubset(self):
        x = self.left
        y = self.right
        for case in "!=", "==", "<", "<=", ">", ">=":
            expected = case in self.cases
            # Test the binary infix operator.
            result = TestSubsets._case2op[case](x, y)
            self.assertEqual(result, expected)
            # Test the "friendly" method-name spelling, if one exists.
            if case in TestSubsets.case2method:
                method = getattr(x, TestSubsets.case2method[case])
                result = method(y)
                self.assertEqual(result, expected)

            # Now do the same for the operands reversed.
            rcase = TestSubsets.reverse[case]
            result = TestSubsets._case2op[rcase](y, x)
            self.assertEqual(result, expected)
            if rcase in TestSubsets.case2method:
                method = getattr(y, TestSubsets.case2method[rcase])
                result = method(x)
                self.assertEqual(result, expected)
1291#------------------------------------------------------------------------------
1292
class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase):
    """Comparison case: both operands are empty sets."""
    left, right = set(), set()
    name = "both empty"
    cases = ("==", "<=", ">=")
1298
1299#------------------------------------------------------------------------------
1300
class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase):
    """Comparison case: two equal, non-empty sets."""
    left, right = {1, 2}, {1, 2}
    name = "equal pair"
    cases = ("==", "<=", ">=")
1306
1307#------------------------------------------------------------------------------
1308
class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase):
    """Comparison case: an empty set against a non-empty one."""
    left, right = set(), {1, 2}
    name = "one empty, one non-empty"
    cases = ("!=", "<", "<=")
1314
1315#------------------------------------------------------------------------------
1316
class TestSubsetPartial(TestSubsets, unittest.TestCase):
    """Comparison case: a non-empty proper subset against its superset."""
    left, right = {1}, {1, 2}
    name = "one a non-empty proper subset of other"
    cases = ("!=", "<", "<=")
1322
1323#------------------------------------------------------------------------------
1324
class TestSubsetNonOverlap(TestSubsets, unittest.TestCase):
    """Comparison case: two non-empty sets with no members in common."""
    left  = set([1])
    right = set([2])
    name  = "neither empty, neither contains"
    # One-element tuple, like the sibling classes.  A bare "!=" string would
    # turn the ``case in self.cases`` check into a substring test.
    cases = ("!=",)
1330
1331#==============================================================================
1332
class TestOnlySetsInBinaryOps:
    """Mixin: set operators reject non-set operands, methods take iterables.

    Subclasses assign ``set``, ``other`` (a non-set operand) and
    ``otherIsIterable`` in ``setUp``.
    """

    def test_eq_ne(self):
        # Unlike the others, this is testing that == and != *are* allowed.
        self.assertEqual(self.other == self.set, False)
        self.assertEqual(self.set == self.other, False)
        self.assertEqual(self.other != self.set, True)
        self.assertEqual(self.set != self.other, True)

    def test_ge_gt_le_lt(self):
        orderings = (operator.lt, operator.le, operator.gt, operator.ge)
        for compare in orderings:
            self.assertRaises(TypeError, compare, self.set, self.other)

        for compare in orderings:
            self.assertRaises(TypeError, compare, self.other, self.set)

    def test_update_operator(self):
        try:
            self.set |= self.other
        except TypeError:
            return
        self.fail("expected TypeError")

    def test_update(self):
        if not self.otherIsIterable:
            self.assertRaises(TypeError, self.set.update, self.other)
            return
        self.set.update(self.other)

    def test_union(self):
        self.assertRaises(TypeError, operator.or_, self.set, self.other)
        self.assertRaises(TypeError, operator.or_, self.other, self.set)
        if not self.otherIsIterable:
            self.assertRaises(TypeError, self.set.union, self.other)
            return
        self.set.union(self.other)

    def test_intersection_update_operator(self):
        try:
            self.set &= self.other
        except TypeError:
            return
        self.fail("expected TypeError")

    def test_intersection_update(self):
        if not self.otherIsIterable:
            self.assertRaises(TypeError,
                              self.set.intersection_update,
                              self.other)
            return
        self.set.intersection_update(self.other)

    def test_intersection(self):
        self.assertRaises(TypeError, operator.and_, self.set, self.other)
        self.assertRaises(TypeError, operator.and_, self.other, self.set)
        if not self.otherIsIterable:
            self.assertRaises(TypeError, self.set.intersection, self.other)
            return
        self.set.intersection(self.other)

    def test_sym_difference_update_operator(self):
        try:
            self.set ^= self.other
        except TypeError:
            return
        self.fail("expected TypeError")

    def test_sym_difference_update(self):
        if not self.otherIsIterable:
            self.assertRaises(TypeError,
                              self.set.symmetric_difference_update,
                              self.other)
            return
        self.set.symmetric_difference_update(self.other)

    def test_sym_difference(self):
        self.assertRaises(TypeError, operator.xor, self.set, self.other)
        self.assertRaises(TypeError, operator.xor, self.other, self.set)
        if not self.otherIsIterable:
            self.assertRaises(TypeError,
                              self.set.symmetric_difference,
                              self.other)
            return
        self.set.symmetric_difference(self.other)

    def test_difference_update_operator(self):
        try:
            self.set -= self.other
        except TypeError:
            return
        self.fail("expected TypeError")

    def test_difference_update(self):
        if not self.otherIsIterable:
            self.assertRaises(TypeError,
                              self.set.difference_update,
                              self.other)
            return
        self.set.difference_update(self.other)

    def test_difference(self):
        self.assertRaises(TypeError, operator.sub, self.set, self.other)
        self.assertRaises(TypeError, operator.sub, self.other, self.set)
        if not self.otherIsIterable:
            self.assertRaises(TypeError, self.set.difference, self.other)
            return
        self.set.difference(self.other)
1446
1447#------------------------------------------------------------------------------
1448
class TestOnlySetsNumeric(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Other operand is an int, which is not iterable."""

    def setUp(self):
        self.set = {1, 2, 3}
        self.other, self.otherIsIterable = 19, False
1454
1455#------------------------------------------------------------------------------
1456
class TestOnlySetsDict(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Other operand is a dict, which is iterable but not a set."""

    def setUp(self):
        self.set = {1, 2, 3}
        self.other, self.otherIsIterable = {1: 2, 3: 4}, True
1462
1463#------------------------------------------------------------------------------
1464
class TestOnlySetsOperator(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Other operand is a builtin function, which is not iterable."""

    def setUp(self):
        self.set = {1, 2, 3}
        self.other, self.otherIsIterable = operator.add, False
1470
1471#------------------------------------------------------------------------------
1472
class TestOnlySetsTuple(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Other operand is a tuple, which is iterable but not a set."""

    def setUp(self):
        self.set = {1, 2, 3}
        self.other, self.otherIsIterable = (2, 4, 6), True
1478
1479#------------------------------------------------------------------------------
1480
class TestOnlySetsString(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Other operand is a str, which is iterable but not a set."""

    def setUp(self):
        self.set = {1, 2, 3}
        self.other, self.otherIsIterable = 'abc', True
1486
1487#------------------------------------------------------------------------------
1488
class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Other operand is a generator, which is iterable but not a set."""

    def setUp(self):
        def gen():
            yield from range(0, 10, 2)
        self.set = {1, 2, 3}
        self.other, self.otherIsIterable = gen(), True
1497
1498#==============================================================================
1499
class TestCopying:
    """Mixin checking set.copy() and copy.deepcopy().

    Subclasses assign the set under test to ``self.set`` in ``setUp``.
    """

    def test_copy(self):
        dup = self.set.copy()
        # Sort by repr so both element lists line up in the same order.
        dup_list = sorted(dup, key=repr)
        set_list = sorted(self.set, key=repr)
        self.assertEqual(len(dup_list), len(set_list))
        # A shallow copy must hold the very same element objects.
        for copied, original in zip(dup_list, set_list):
            self.assertIs(copied, original)

    def test_deep_copy(self):
        dup = copy.deepcopy(self.set)
        # Sort by repr so both element lists line up in the same order.
        dup_list = sorted(dup, key=repr)
        set_list = sorted(self.set, key=repr)
        self.assertEqual(len(dup_list), len(set_list))
        # A deep copy only has to hold equal elements, not identical ones.
        for copied, original in zip(dup_list, set_list):
            self.assertEqual(copied, original)
1518
1519#------------------------------------------------------------------------------
1520
class TestCopyingEmpty(TestCopying, unittest.TestCase):
    """Run the copying checks on an empty set."""

    def setUp(self):
        no_members = ()
        self.set = set(no_members)
1524
1525#------------------------------------------------------------------------------
1526
class TestCopyingSingleton(TestCopying, unittest.TestCase):
    """Run the copying checks on a set holding one string."""

    def setUp(self):
        self.set = {"hello"}
1530
1531#------------------------------------------------------------------------------
1532
class TestCopyingTriple(TestCopying, unittest.TestCase):
    """Run the copying checks on three members of different types."""

    def setUp(self):
        self.set = {"zero", 0, None}
1536
1537#------------------------------------------------------------------------------
1538
class TestCopyingTuple(TestCopying, unittest.TestCase):
    """Run the copying checks on a set holding one tuple."""

    def setUp(self):
        self.set = {(1, 2)}
1542
1543#------------------------------------------------------------------------------
1544
class TestCopyingNested(TestCopying, unittest.TestCase):
    """Run the copying checks on a set holding one nested tuple."""

    def setUp(self):
        nested = ((1, 2), (3, 4))
        self.set = {nested}
1548
1549#==============================================================================
1550
class TestIdentities(unittest.TestCase):
    """Algebraic identities relating the set operators to one another."""

    def setUp(self):
        self.a = set('abracadabra')
        self.b = set('alacazam')

    def test_binopsVsSubsets(self):
        a, b = self.a, self.b
        # Differences and intersections are proper subsets of an operand.
        self.assertLess(a - b, a)
        self.assertLess(b - a, b)
        self.assertLess(a & b, a)
        self.assertLess(a & b, b)
        # The union is a proper superset of each operand.
        self.assertGreater(a | b, a)
        self.assertGreater(a | b, b)
        self.assertLess(a ^ b, a | b)

    def test_commutativity(self):
        a, b = self.a, self.b
        self.assertEqual(a & b, b & a)
        self.assertEqual(a | b, b | a)
        self.assertEqual(a ^ b, b ^ a)
        # Difference is only checked when the operands differ.
        if a != b:
            self.assertNotEqual(a - b, b - a)

    def test_summations(self):
        # check that sums of parts equal the whole
        a, b = self.a, self.b
        union = a | b
        self.assertEqual((a - b) | (a & b) | (b - a), union)
        self.assertEqual((a & b) | (a ^ b), union)
        self.assertEqual(a | (b - a), union)
        self.assertEqual((a - b) | b, union)
        self.assertEqual((a - b) | (a & b), a)
        self.assertEqual((b - a) | (a & b), b)
        self.assertEqual((a - b) | (b - a), a ^ b)

    def test_exclusion(self):
        # check that inverse operations show non-overlap
        a, b = self.a, self.b
        zero = set()
        self.assertEqual((a - b) & b, zero)
        self.assertEqual((b - a) & a, zero)
        self.assertEqual((a & b) & (a ^ b), zero)
1591
1592# Tests derived from test_itertools.py =======================================
1593
def R(seqn):
    'Plain generator yielding every item of seqn'
    yield from seqn
1598
class G:
    'Wrapper that is iterable only through __getitem__'

    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, i):
        # Delegate indexing (and its IndexError) to the wrapped sequence.
        wrapped = self.seqn
        return wrapped[i]
1605
class I:
    'Wrapper implementing the iterator protocol with an index counter'

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        position = self.i
        if position >= len(self.seqn):
            raise StopIteration
        item = self.seqn[position]
        # Advance only after the item was fetched successfully.
        self.i = position + 1
        return item
1618
class Ig:
    'Wrapper whose __iter__ is a generator over the wrapped sequence'

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        yield from self.seqn
1627
class X:
    'Has __next__ but neither __getitem__ nor __iter__, so is not iterable'

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __next__(self):
        position = self.i
        if position >= len(self.seqn):
            raise StopIteration
        item = self.seqn[position]
        # Advance only after the item was fetched successfully.
        self.i = position + 1
        return item
1638
class N:
    'Returns itself from __iter__ but defines no __next__()'

    def __init__(self, seqn):
        self.i = 0
        self.seqn = seqn

    def __iter__(self):
        # Deliberately not a real iterator: there is no __next__.
        return self
1646
class E:
    'Iterator whose __next__ always raises ZeroDivisionError'

    def __init__(self, seqn):
        self.i = 0
        self.seqn = seqn

    def __iter__(self):
        return self

    def __next__(self):
        # The division by zero is the whole point of this class.
        return 3 // 0
1656
class S:
    'Iterator that is exhausted immediately, ignoring its argument'

    def __init__(self, seqn):
        # seqn is accepted for a uniform signature and is not stored.
        return None

    def __iter__(self):
        return self

    def __next__(self):
        raise StopIteration()
1665
1666from itertools import chain
def L(seqn):
    'Stack several iterator layers (G, Ig, R, map, chain) on top of seqn'
    layered = R(Ig(G(seqn)))
    passthrough = map(lambda item: item, layered)
    return chain(passthrough)
1670
class TestVariousIteratorArgs(unittest.TestCase):
    """Feed set constructors and methods many kinds of iterable argument.

    The well-behaved wrappers (G, I, Ig, S, L, R) must give the same result
    as the equivalent plain data.  X and N must be rejected with TypeError,
    and the ZeroDivisionError raised by E must propagate unchanged.
    """

    def test_constructor(self):
        for cons in (set, frozenset):
            for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
                for g in (G, I, Ig, S, L, R):
                    # Sort by repr because members may be of mixed types.
                    self.assertEqual(sorted(cons(g(s)), key=repr),
                                     sorted(g(s), key=repr))
                self.assertRaises(TypeError, cons, X(s))
                self.assertRaises(TypeError, cons, N(s))
                self.assertRaises(ZeroDivisionError, cons, E(s))

    def test_inline_methods(self):
        s = set('november')
        for data in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5), 'december'):
            for meth in (s.union, s.intersection, s.difference,
                         s.symmetric_difference, s.isdisjoint):
                # The reference result does not depend on the wrapper, so
                # compute it once per (data, method) pair.
                expected = meth(data)
                for g in (G, I, Ig, L, R):
                    actual = meth(g(data))
                    if isinstance(expected, bool):
                        self.assertEqual(actual, expected)
                    else:
                        self.assertEqual(sorted(actual, key=repr),
                                         sorted(expected, key=repr))
                # Wrap the current data (not the set under test), matching
                # test_inplace_methods.
                self.assertRaises(TypeError, meth, X(data))
                self.assertRaises(TypeError, meth, N(data))
                self.assertRaises(ZeroDivisionError, meth, E(data))

    def test_inplace_methods(self):
        for data in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5), 'december'):
            for methname in ('update', 'intersection_update',
                             'difference_update', 'symmetric_difference_update'):
                for g in (G, I, Ig, S, L, R):
                    # Apply the method to two equal sets, once with a list
                    # built from the wrapper and once with the wrapper
                    # itself; both must end up equal.
                    s = set('january')
                    t = s.copy()
                    getattr(s, methname)(list(g(data)))
                    getattr(t, methname)(g(data))
                    self.assertEqual(sorted(s, key=repr), sorted(t, key=repr))

                self.assertRaises(TypeError, getattr(set('january'), methname), X(data))
                self.assertRaises(TypeError, getattr(set('january'), methname), N(data))
                self.assertRaises(ZeroDivisionError, getattr(set('january'), methname), E(data))
1711
class bad_eq:
    """Object with a constant hash whose __eq__ misbehaves on demand.

    While the module global ``be_bad`` is true, comparing clears the module
    global ``set2`` and raises ZeroDivisionError; otherwise comparison is by
    identity.
    """

    def __hash__(self):
        return 0

    def __eq__(self, other):
        if not be_bad:
            return self is other
        set2.clear()
        raise ZeroDivisionError
1720
class bad_dict_clear:
    """Object with a constant hash whose __eq__ can empty a dict.

    While the module global ``be_bad`` is true, each comparison first clears
    the module global ``dict2``.  The result is always an identity check.
    """

    def __hash__(self):
        return 0

    def __eq__(self, other):
        if be_bad:
            # Side effect only; the comparison result is unaffected.
            dict2.clear()
        return self is other
1728
class TestWeirdBugs(unittest.TestCase):
    # Regression tests in which a container is mutated from inside __eq__, or
    # between creating an iterator and using it.  Apart from the one
    # assertRaises below, each test passes simply by completing.
    def test_8420_set_merge(self):
        # This used to segfault
        # be_bad, set2 and dict2 are module globals because bad_eq.__eq__ and
        # bad_dict_clear.__eq__ read them.
        global be_bad, set2, dict2
        be_bad = False  # build the sets while comparisons are well behaved
        set1 = {bad_eq()}
        set2 = {bad_eq() for i in range(75)}
        be_bad = True
        # Every bad_eq hashes to 0; with be_bad set, bad_eq.__eq__ clears set2
        # and raises ZeroDivisionError, which update() must propagate.
        self.assertRaises(ZeroDivisionError, set1.update, set2)

        be_bad = False
        set1 = {bad_dict_clear()}
        dict2 = {bad_dict_clear(): None}
        be_bad = True
        # bad_dict_clear.__eq__ now clears dict2 but does not raise, so this
        # call is expected to return normally.
        set1.symmetric_difference_update(dict2)

    def test_iter_and_mutate(self):
        # Issue #24581
        s = set(range(100))
        s.clear()
        s.update(range(100))
        si = iter(s)
        # Empty and refill the set while the iterator `si` is still alive.
        s.clear()
        # NOTE(review): `a` is never read; presumably this allocation is part
        # of reproducing the original failure -- confirm before removing.
        a = list(range(100))
        s.update(range(100))
        # Exhaust the iterator created before the mutation; the result is
        # discarded.
        list(si)

    def test_merge_and_mutate(self):
        # X hashes like the int 0 (the element already in s) and its __eq__
        # clears `other`, the set being merged, before returning False.
        class X:
            def __hash__(self):
                return hash(0)
            def __eq__(self, o):
                other.clear()
                return False

        # NOTE(review): this first binding is replaced on the next line.
        other = set()
        other = {X() for i in range(10)}
        s = {0}
        # No exception is expected here.
        s.update(other)
1768
# Application tests (based on David Eppstein's graph recipes) ==================================
1770
def powerset(U):
    """Generates all subsets of a set or sequence U.

    Each subset is yielded as a frozenset.  The first element is peeled off
    the iterator and the generator recurses on what remains; when the
    iterator is exhausted the empty frozenset is yielded.
    """
    remaining = iter(U)
    try:
        head = frozenset((next(remaining),))
        for subset in powerset(remaining):
            # Emit each subset of the tail without and then with the head.
            yield subset
            yield subset | head
    except StopIteration:
        yield frozenset()
1781
def cube(n):
    """Graph of n-dimensional hypercube.

    Returns a dict mapping every subset of range(n), as a frozenset, to the
    frozenset of subsets obtained by toggling exactly one element.
    """
    axes = [frozenset((i,)) for i in range(n)]
    graph = {}
    for vertex in powerset(range(n)):
        # Symmetric difference with a singleton toggles that one element.
        graph[vertex] = frozenset(vertex ^ axis for axis in axes)
    return graph
1787
def linegraph(G):
    """Graph, the vertices of which are edges of G,
    with two vertices being adjacent iff the corresponding
    edges share a vertex.

    G maps each vertex to its neighbours; the result maps each edge,
    a frozenset of two vertices, to a frozenset of adjacent edges.
    """
    def other_edges(vertex, excluded):
        # Edges at `vertex`, leaving out the one that leads to `excluded`.
        return [frozenset([vertex, z]) for z in G[vertex] if z != excluded]

    result = {}
    for x in G:
        for y in G[x]:
            result[frozenset([x, y])] = frozenset(other_edges(x, y) + other_edges(y, x))
    return result
1799
def faces(G):
    """Return a set of faces in G.  Where a face is a set of vertices on that face

    G maps each vertex to its neighbours.  Each face is a frozenset of the
    vertices on a closed walk that starts and ends at the same vertex.
    """
    # currently limited to triangles, squares, and pentagons
    found = set()
    for v1, neighbours in G.items():
        for v2 in neighbours:
            for v3 in G[v2]:
                if v1 == v3:
                    continue
                if v1 in G[v3]:
                    # The walk closes after three vertices.
                    found.add(frozenset([v1, v2, v3]))
                    continue
                for v4 in G[v3]:
                    if v4 == v2:
                        continue
                    if v1 in G[v4]:
                        # The walk closes after four vertices.
                        found.add(frozenset([v1, v2, v3, v4]))
                        continue
                    for v5 in G[v4]:
                        if v5 == v3 or v5 == v2:
                            continue
                        if v1 in G[v5]:
                            # The walk closes after five vertices.
                            found.add(frozenset([v1, v2, v3, v4, v5]))
    return found
1824
1825
class TestGraphs(unittest.TestCase):
    # Application-style checks built on cube(), linegraph() and faces().

    def test_cube(self):
        graph = cube(3)                         # vertex --> its neighbours
        keys = set(graph)
        self.assertEqual(len(keys), 8)          # eight vertices
        for neighbours in graph.values():
            self.assertEqual(len(neighbours), 3)    # three neighbours per vertex
        reachable = {v for neighbours in graph.values() for v in neighbours}
        self.assertEqual(keys, reachable)       # neighbours are all known vertices

        squares = faces(graph)
        self.assertEqual(len(squares), 6)       # six faces
        for face in squares:
            self.assertEqual(len(face), 4)      # each face is a square

    def test_cuboctahedron(self):
        # http://en.wikipedia.org/wiki/Cuboctahedron
        # 8 triangular faces and 6 square faces
        # 12 identical vertices each connecting a triangle and square

        g = cube(3)
        cuboctahedron = linegraph(g)            # vertex --> its four neighbours
        self.assertEqual(len(cuboctahedron), 12)    # twelve vertices

        vertices = set(cuboctahedron)
        for adjacent in cuboctahedron.values():
            self.assertEqual(len(adjacent), 4)  # each vertex connects to four other vertices
        othervertices = {v for adjacent in cuboctahedron.values() for v in adjacent}
        self.assertEqual(vertices, othervertices)   # neighbours are all known vertices

        # Tally the faces by their number of vertices.
        facesizes = collections.Counter(len(face) for face in faces(cuboctahedron))
        self.assertEqual(facesizes[3], 8)       # eight triangular faces
        self.assertEqual(facesizes[4], 6)       # six square faces

        # Cuboctahedron vertices are edges in Cube
        for edge in cuboctahedron:
            self.assertEqual(len(edge), 2)      # Two cube vertices define an edge
            for cubevert in edge:
                self.assertIn(cubevert, g)
1871
1872
1873#==============================================================================
1874
# Run this module's tests when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
1877