• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1
2import unittest
3from test import test_support
4import gc
5import weakref
6import operator
7import copy
8import pickle
9from random import randrange, shuffle
10import sys
11import collections
12
class PassThru(Exception):
    """Sentinel exception raised by check_pass_thru().

    Tests pass a check_pass_thru() generator to set constructors and set
    methods and expect this exception to come back out, proving that errors
    raised while iterating an argument are propagated rather than swallowed.
    """
15
def check_pass_thru():
    """Return a generator that raises PassThru on its first ``next()``.

    The unreachable ``yield`` is what makes this function a generator: calling
    it does not raise, only iterating the result does.
    """
    raise PassThru()
    yield 1  # never executed; only turns the function into a generator
19
class BadCmp:
    """Hashable object whose ``__cmp__`` always raises RuntimeError.

    Every instance has the same hash, so two instances placed in one set
    collide.  Tests use this to check that an error raised while comparing
    elements escapes from set insertion, lookup and removal.
    """

    def __hash__(self):
        # Constant hash value: all instances collide with each other.
        return 1

    def __cmp__(self, other):
        raise RuntimeError()
25
class ReprWrapper:
    """Object whose repr() is simply the repr() of its ``value`` attribute.

    ``value`` is not set by the class itself; callers assign it after
    construction.  Tests store a set that contains the wrapper into ``value``
    to build a self-referential structure for recursive-repr checks.
    """

    def __repr__(self):
        wrapped = self.value
        return repr(wrapped)
30
class HashCountingInt(int):
    """int subclass that counts how many times it has been hashed.

    ``hash_count`` starts at 0 and grows by one on every ``__hash__`` call;
    the returned hash is the ordinary int hash of the value.
    """

    def __init__(self, *args):
        # The integer value itself is already fixed by int.__new__; only the
        # counter needs setting up here.
        self.hash_count = 0

    def __hash__(self):
        self.hash_count = self.hash_count + 1
        return super(HashCountingInt, self).__hash__()
38
class TestJointOps(unittest.TestCase):
    """Tests common to both set and frozenset.

    Concrete subclasses set the class attribute ``thetype`` to the type under
    test (set, frozenset, or a subclass of either).  ``setUp`` builds
    ``self.s`` from ``self.word`` and a reference dict ``self.d`` with the
    same keys, which most tests use as an oracle for membership.
    """

    def setUp(self):
        self.word = word = 'simsalabim'
        self.otherword = 'madagascar'
        self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
        self.s = self.thetype(word)
        # Reference model: a dict whose keys are exactly the elements of self.s.
        self.d = dict.fromkeys(word)

    def test_new_or_init(self):
        self.assertRaises(TypeError, self.thetype, [], 2)
        self.assertRaises(TypeError, set().__init__, a=1)

    def test_uniquification(self):
        actual = sorted(self.s)
        expected = sorted(self.d)
        self.assertEqual(actual, expected)
        # Errors raised while consuming the input iterable must propagate.
        self.assertRaises(PassThru, self.thetype, check_pass_thru())
        # Unhashable elements are rejected.
        self.assertRaises(TypeError, self.thetype, [[]])

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        self.assertRaises(TypeError, self.s.__contains__, [[]])
        # A set equal to a stored frozenset must be found, even when
        # thetype is the (unhashable) mutable set.
        s = self.thetype([frozenset(self.letters)])
        self.assertIn(self.thetype(self.letters), s)

    def test_union(self):
        u = self.s.union(self.otherword)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.otherword)
        # union() must leave the receiver unchanged and return thetype.
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(u), self.thetype)
        self.assertRaises(PassThru, self.s.union, check_pass_thru())
        self.assertRaises(TypeError, self.s.union, [[]])
        for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
            self.assertEqual(self.thetype('abcba').union(C('cdc')), set('abcd'))
            self.assertEqual(self.thetype('abcba').union(C('efgfe')), set('abcefg'))
            self.assertEqual(self.thetype('abcba').union(C('ccb')), set('abc'))
            self.assertEqual(self.thetype('abcba').union(C('ef')), set('abcef'))
            self.assertEqual(self.thetype('abcba').union(C('ef'), C('fg')), set('abcefg'))

        # Issue #6573
        x = self.thetype()
        self.assertEqual(x.union(set([1]), x, set([2])), self.thetype([1, 2]))

    def test_or(self):
        i = self.s.union(self.otherword)
        self.assertEqual(self.s | set(self.otherword), i)
        self.assertEqual(self.s | frozenset(self.otherword), i)
        # Unlike union(), the | operator only accepts set operands.
        try:
            self.s | self.otherword
        except TypeError:
            pass
        else:
            self.fail("s|t did not screen-out general iterables")

    def test_intersection(self):
        i = self.s.intersection(self.otherword)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c in self.otherword)
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(i), self.thetype)
        self.assertRaises(PassThru, self.s.intersection, check_pass_thru())
        for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
            self.assertEqual(self.thetype('abcba').intersection(C('cdc')), set('cc'))
            self.assertEqual(self.thetype('abcba').intersection(C('efgfe')), set(''))
            self.assertEqual(self.thetype('abcba').intersection(C('ccb')), set('bc'))
            self.assertEqual(self.thetype('abcba').intersection(C('ef')), set(''))
            self.assertEqual(self.thetype('abcba').intersection(C('cbcf'), C('bag')), set('b'))
        s = self.thetype('abcba')
        z = s.intersection()
        # For every thetype, including frozenset, a zero-argument
        # intersection() is expected to return a distinct object.  (Comparing
        # self.thetype against frozenset() -- an empty frozenset *instance* --
        # can never be true, so no "same object" case is reachable here.)
        self.assertIsNot(s, z)

    def test_isdisjoint(self):
        def f(s1, s2):
            'Pure python equivalent of isdisjoint()'
            return not set(s1).intersection(s2)
        for larg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
            s1 = self.thetype(larg)
            for rarg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
                for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
                    s2 = C(rarg)
                    actual = s1.isdisjoint(s2)
                    expected = f(s1, s2)
                    self.assertEqual(actual, expected)
                    # isdisjoint() must return a real bool, not just a truthy value.
                    self.assertTrue(actual is True or actual is False)

    def test_and(self):
        i = self.s.intersection(self.otherword)
        self.assertEqual(self.s & set(self.otherword), i)
        self.assertEqual(self.s & frozenset(self.otherword), i)
        try:
            self.s & self.otherword
        except TypeError:
            pass
        else:
            self.fail("s&t did not screen-out general iterables")

    def test_difference(self):
        i = self.s.difference(self.otherword)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.otherword)
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(i), self.thetype)
        self.assertRaises(PassThru, self.s.difference, check_pass_thru())
        self.assertRaises(TypeError, self.s.difference, [[]])
        for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
            self.assertEqual(self.thetype('abcba').difference(C('cdc')), set('ab'))
            self.assertEqual(self.thetype('abcba').difference(C('efgfe')), set('abc'))
            self.assertEqual(self.thetype('abcba').difference(C('ccb')), set('a'))
            self.assertEqual(self.thetype('abcba').difference(C('ef')), set('abc'))
            self.assertEqual(self.thetype('abcba').difference(), set('abc'))
            self.assertEqual(self.thetype('abcba').difference(C('a'), C('b')), set('c'))

    def test_sub(self):
        i = self.s.difference(self.otherword)
        self.assertEqual(self.s - set(self.otherword), i)
        self.assertEqual(self.s - frozenset(self.otherword), i)
        try:
            self.s - self.otherword
        except TypeError:
            pass
        else:
            self.fail("s-t did not screen-out general iterables")

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.otherword)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.otherword))
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(i), self.thetype)
        self.assertRaises(PassThru, self.s.symmetric_difference, check_pass_thru())
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('cdc')), set('abd'))
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('efgfe')), set('abcefg'))
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('ccb')), set('a'))
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('ef')), set('abcef'))

    def test_xor(self):
        i = self.s.symmetric_difference(self.otherword)
        self.assertEqual(self.s ^ set(self.otherword), i)
        self.assertEqual(self.s ^ frozenset(self.otherword), i)
        try:
            self.s ^ self.otherword
        except TypeError:
            pass
        else:
            self.fail("s^t did not screen-out general iterables")

    def test_equality(self):
        self.assertEqual(self.s, set(self.word))
        self.assertEqual(self.s, frozenset(self.word))
        # A set never compares equal to a non-set iterable.
        self.assertEqual(self.s == self.word, False)
        self.assertNotEqual(self.s, set(self.otherword))
        self.assertNotEqual(self.s, frozenset(self.otherword))
        self.assertEqual(self.s != self.word, True)

    def test_setOfFrozensets(self):
        # 'abcdef'/'fedccba' and 'bcd'/'bdcb' collapse pairwise, leaving 3.
        t = map(frozenset, ['abcdef', 'bcd', 'bdcb', 'fed', 'fedccba'])
        s = self.thetype(t)
        self.assertEqual(len(s), 3)

    def test_compare(self):
        # Sets do not support three-way comparison.
        self.assertRaises(TypeError, self.s.__cmp__, self.s)

    def test_sub_and_super(self):
        p, q, r = map(self.thetype, ['ab', 'abcde', 'def'])
        self.assertTrue(p < q)
        self.assertTrue(p <= q)
        self.assertTrue(q <= q)
        self.assertTrue(q > p)
        self.assertTrue(q >= p)
        self.assertFalse(q < r)
        self.assertFalse(q <= r)
        self.assertFalse(q > r)
        self.assertFalse(q >= r)
        self.assertTrue(set('a').issubset('abc'))
        self.assertTrue(set('abc').issuperset('a'))
        self.assertFalse(set('a').issubset('cbs'))
        self.assertFalse(set('cbs').issuperset('a'))

    def test_pickling(self):
        for i in range(pickle.HIGHEST_PROTOCOL + 1):
            p = pickle.dumps(self.s, i)
            dup = pickle.loads(p)
            self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup))
            # Subclass instances must also round-trip their instance attributes.
            if type(self.s) not in (set, frozenset):
                self.s.x = 10
                p = pickle.dumps(self.s, i)
                dup = pickle.loads(p)
                self.assertEqual(self.s.x, dup.x)

    def test_deepcopy(self):
        # Tracer.__deepcopy__ returns a *new* Tracer whose value is one higher,
        # so the copied element can be told apart from the original.
        class Tracer:
            def __init__(self, value):
                self.value = value
            def __hash__(self):
                return self.value
            def __deepcopy__(self, memo=None):
                return Tracer(self.value + 1)
        t = Tracer(10)
        s = self.thetype([t])
        dup = copy.deepcopy(s)
        self.assertNotEqual(id(s), id(dup))
        for elem in dup:   # dup holds exactly one element
            newt = elem
        self.assertNotEqual(id(t), id(newt))
        self.assertEqual(t.value + 1, newt.value)

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check
        class A:
            pass
        s = set(A() for i in xrange(1000))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = set([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        # A subclass defining __hash__ must be storable in, findable in and
        # removable from a plain set.
        class H(self.thetype):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s=H()
        f=set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_badcmp(self):
        s = self.thetype([BadCmp()])
        # Detect comparison errors during insertion and lookup
        self.assertRaises(RuntimeError, self.thetype, [BadCmp(), BadCmp()])
        self.assertRaises(RuntimeError, s.__contains__, BadCmp())
        # Detect errors during mutating operations
        if hasattr(s, 'add'):
            self.assertRaises(RuntimeError, s.add, BadCmp())
            self.assertRaises(RuntimeError, s.discard, BadCmp())
            self.assertRaises(RuntimeError, s.remove, BadCmp())

    def test_cyclical_repr(self):
        w = ReprWrapper()
        s = self.thetype([w])
        w.value = s
        name = repr(s).partition('(')[0]    # strip class name from repr string
        self.assertEqual(repr(s), '%s([%s(...)])' % (name, name))

    def test_cyclical_print(self):
        w = ReprWrapper()
        s = self.thetype([w])
        w.value = s
        fo = open(test_support.TESTFN, "wb")
        try:
            print >> fo, s,
            fo.close()
            fo = open(test_support.TESTFN, "rb")
            # Printing must produce exactly the (recursion-safe) repr.
            self.assertEqual(fo.read(), repr(s))
        finally:
            fo.close()
            test_support.unlink(test_support.TESTFN)

    def test_do_not_rehash_dict_keys(self):
        # Building sets/dicts from existing dicts, sets and frozensets must
        # reuse the stored hashes: the per-key __hash__ counters stay at n.
        n = 10
        d = dict.fromkeys(map(HashCountingInt, xrange(n)))
        self.assertEqual(sum(elem.hash_count for elem in d), n)
        s = self.thetype(d)
        self.assertEqual(sum(elem.hash_count for elem in d), n)
        s.difference(d)
        self.assertEqual(sum(elem.hash_count for elem in d), n)
        if hasattr(s, 'symmetric_difference_update'):
            s.symmetric_difference_update(d)
        self.assertEqual(sum(elem.hash_count for elem in d), n)
        d2 = dict.fromkeys(set(d))
        self.assertEqual(sum(elem.hash_count for elem in d), n)
        d3 = dict.fromkeys(frozenset(d))
        self.assertEqual(sum(elem.hash_count for elem in d), n)
        d3 = dict.fromkeys(frozenset(d), 123)
        self.assertEqual(sum(elem.hash_count for elem in d), n)
        self.assertEqual(d3, dict.fromkeys(d, 123))

    def test_container_iterator(self):
        # Bug #3680: tp_traverse was not implemented for set iterator object
        class C(object):
            pass
        obj = C()
        ref = weakref.ref(obj)
        container = set([obj, 1])
        obj.x = iter(container)
        del obj, container
        gc.collect()
        self.assertTrue(ref() is None, "Cycle was not collected")

    def test_free_after_iterating(self):
        test_support.check_free_after_iterating(self, iter, self.thetype)
345
class TestSet(TestJointOps):
    """Joint set/frozenset tests plus the mutating API only ``set`` provides.

    Covers re-initialisation, unhashability, add/remove/discard/pop/clear,
    the *_update methods, and the in-place operators.
    """
    thetype = set

    def test_init(self):
        s = self.thetype()
        s.__init__(self.word)
        self.assertEqual(s, set(self.word))
        # Re-initialising replaces the previous contents.
        s.__init__(self.otherword)
        self.assertEqual(s, set(self.otherword))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = self.thetype(range(3))
        t = self.thetype(s)
        self.assertNotEqual(id(s), id(t))

    def test_set_literal_insertion_order(self):
        # SF Issue #26020 -- Expect left to right insertion
        s = {1, 1.0, True}
        self.assertEqual(len(s), 1)
        stored_value = s.pop()
        self.assertEqual(type(stored_value), int)

    def test_set_literal_evaluation_order(self):
        # Expect left to right expression evaluation
        events = []
        def record(obj):
            events.append(obj)
        s = {record(1), record(2), record(3)}
        self.assertEqual(events, [1, 2, 3])

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, set())
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertNotEqual(id(self.s), id(dup))

    def test_add(self):
        self.s.add('Q')
        self.assertIn('Q', self.s)
        # Adding an existing element is a no-op.
        dup = self.s.copy()
        self.s.add('Q')
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])

    def test_remove(self):
        self.s.remove('a')
        self.assertNotIn('a', self.s)
        self.assertRaises(KeyError, self.s.remove, 'Q')
        self.assertRaises(TypeError, self.s.remove, [])
        # A mutable set argument is matched against an equal frozenset element.
        s = self.thetype([frozenset(self.word)])
        self.assertIn(self.thetype(self.word), s)
        s.remove(self.thetype(self.word))
        self.assertNotIn(self.thetype(self.word), s)
        self.assertRaises(KeyError, self.s.remove, self.thetype(self.word))

    def test_remove_keyerror_unpacking(self):
        # bug:  www.python.org/sf/1576657
        # The missing key -- including a tuple key -- must appear unchanged
        # as KeyError.args[0].
        for v1 in ['Q', (1,)]:
            try:
                self.s.remove(v1)
            except KeyError as e:
                v2 = e.args[0]
                self.assertEqual(v1, v2)
            else:
                self.fail()

    def test_remove_keyerror_set(self):
        key = self.thetype([3, 4])
        try:
            self.s.remove(key)
        except KeyError as e:
            self.assertTrue(e.args[0] is key,
                         "KeyError should be {0}, not {1}".format(key,
                                                                  e.args[0]))
        else:
            self.fail()

    def test_discard(self):
        self.s.discard('a')
        self.assertNotIn('a', self.s)
        # Discarding a missing element is silently ignored.
        self.s.discard('Q')
        self.assertRaises(TypeError, self.s.discard, [])
        s = self.thetype([frozenset(self.word)])
        self.assertIn(self.thetype(self.word), s)
        s.discard(self.thetype(self.word))
        self.assertNotIn(self.thetype(self.word), s)
        s.discard(self.thetype(self.word))

    def test_pop(self):
        for i in xrange(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            self.assertIn(c, self.s)
        self.assertRaises(PassThru, self.s.update, check_pass_thru())
        self.assertRaises(TypeError, self.s.update, [[]])
        for p, q in (('cdc', 'abcd'), ('efgfe', 'abcefg'), ('ccb', 'abc'), ('ef', 'abcef')):
            for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.update(C(p)), None)
                self.assertEqual(s, set(q))
        for p in ('cdc', 'efgfe', 'ccb', 'ef', 'abcda'):
            q = 'ahi'
            for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.update(C(p), C(q)), None)
                # Build the expectation from the starting contents, not from
                # s itself, so lost or spurious elements are detected too.
                self.assertEqual(s, set('abcba') | set(p) | set(q))

    def test_ior(self):
        self.s |= set(self.otherword)
        for c in (self.word + self.otherword):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if c in self.otherword and c in self.word:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.intersection_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.intersection_update, [[]])
        for p, q in (('cdc', 'c'), ('efgfe', ''), ('ccb', 'bc'), ('ef', '')):
            for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.intersection_update(C(p)), None)
                self.assertEqual(s, set(q))
                ss = 'abcba'
                s = self.thetype(ss)
                t = 'cbc'
                self.assertEqual(s.intersection_update(C(p), C(t)), None)
                self.assertEqual(s, set('abcba')&set(p)&set(t))

    def test_iand(self):
        self.s &= set(self.otherword)
        for c in (self.word + self.otherword):
            if c in self.otherword and c in self.word:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if c in self.word and c not in self.otherword:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.difference_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
        for p, q in (('cdc', 'ab'), ('efgfe', 'abc'), ('ccb', 'a'), ('ef', 'abc')):
            for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.difference_update(C(p)), None)
                self.assertEqual(s, set(q))

        # These checks do not depend on (p, q); run them once per C instead
        # of once per (p, q, C) combination.
        for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
            s = self.thetype('abcdefghih')
            s.difference_update()
            self.assertEqual(s, self.thetype('abcdefghih'))

            s = self.thetype('abcdefghih')
            s.difference_update(C('aba'))
            self.assertEqual(s, self.thetype('cdefghih'))

            s = self.thetype('abcdefghih')
            s.difference_update(C('cdc'), C('aba'))
            self.assertEqual(s, self.thetype('efghih'))

    def test_isub(self):
        self.s -= set(self.otherword)
        for c in (self.word + self.otherword):
            if c in self.word and c not in self.otherword:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if (c in self.word) ^ (c in self.otherword):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.symmetric_difference_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
        for p, q in (('cdc', 'abd'), ('efgfe', 'abcefg'), ('ccb', 'a'), ('ef', 'abcef')):
            for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.symmetric_difference_update(C(p)), None)
                self.assertEqual(s, set(q))

    def test_ixor(self):
        self.s ^= set(self.otherword)
        for c in (self.word + self.otherword):
            if (c in self.word) ^ (c in self.otherword):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        # In-place operators with the same object on both sides.
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, self.thetype())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, self.thetype())

    def test_weakref(self):
        s = self.thetype('gallahad')
        p = weakref.proxy(s)
        self.assertEqual(str(p), str(s))
        # Once the referent is gone, using the proxy raises ReferenceError.
        s = None
        self.assertRaises(ReferenceError, str, p)

    @unittest.skipUnless(hasattr(set, "test_c_api"),
                         'C API test only available in a debug build')
    def test_c_api(self):
        self.assertEqual(set().test_c_api(), True)
586
class SetSubclass(set):
    """Trivial set subclass; TestSetSubclass runs the TestSet suite on it."""
589
class TestSetSubclass(TestSet):
    """Re-run every TestSet test with SetSubclass as the type under test."""

    thetype = SetSubclass
592
class SetSubclassWithKeywordArgs(set):
    """set subclass whose constructor accepts an extra keyword argument.

    ``newarg`` is accepted and ignored.  It exists so tests can check that
    passing a keyword argument to a set subclass constructor does not raise
    (SF bug #1486663).
    """

    def __init__(self, iterable=(), newarg=None):
        # An immutable empty tuple as the default avoids the shared
        # mutable-default pitfall; set.__init__ treats it just like [].
        set.__init__(self, iterable)
596
class TestSetSubclassWithKeywordArgs(TestSet):
    """TestSet suite (``thetype`` inherited, i.e. plain set) plus a check
    that SetSubclassWithKeywordArgs accepts its extra keyword argument."""

    def test_keywords_in_subclass(self):
        'SF bug #1486663 -- this used to erroneously raise a TypeError'
        # Passing the keyword must simply not raise; nothing else is checked.
        SetSubclassWithKeywordArgs(newarg=1)
602
class TestFrozenSet(TestJointOps):
    """Joint set/frozenset tests plus frozenset-specific behaviour.

    Checks that ``__init__`` cannot mutate a frozenset, that empty
    frozensets share one object, that construction from a frozenset and
    ``copy()`` return the same object, and several hashing properties.
    """

    thetype = frozenset

    def test_init(self):
        fs = self.thetype(self.word)
        # Calling __init__ again must leave the frozen contents untouched.
        fs.__init__(self.otherword)
        self.assertEqual(fs, set(self.word))

    def test_singleton_empty_frozenset(self):
        f = frozenset()
        candidates = [frozenset(), frozenset([]), frozenset(()), frozenset(''),
                      frozenset(), frozenset([]), frozenset(()), frozenset(''),
                      frozenset(xrange(0)), frozenset(frozenset()),
                      frozenset(f), f]
        # Every empty frozenset above is expected to be one shared object.
        distinct_ids = set(id(obj) for obj in candidates)
        self.assertEqual(len(distinct_ids), 1)

    def test_constructor_identity(self):
        original = self.thetype(range(3))
        rebuilt = self.thetype(original)
        self.assertEqual(id(original), id(rebuilt))

    def test_hash(self):
        self.assertEqual(hash(self.thetype('abcdeb')),
                         hash(self.thetype('ebecda')))

        # Every ordering of the same elements must hash identically.
        n = 100
        seq = [randrange(n) for _ in xrange(n)]
        seen_hashes = set()
        for _ in xrange(200):
            shuffle(seq)
            seen_hashes.add(hash(self.thetype(seq)))
        self.assertEqual(len(seen_hashes), 1)

    def test_copy(self):
        duplicate = self.s.copy()
        self.assertEqual(id(self.s), id(duplicate))

    def test_frozen_as_dictkey(self):
        seq = list(range(10)) + list('abcdefg') + ['apple']
        forward = self.thetype(seq)
        backward = self.thetype(reversed(seq))
        self.assertEqual(forward, backward)
        self.assertNotEqual(id(forward), id(backward))
        # Equal frozensets must be interchangeable as dict keys.
        table = {forward: 42}
        self.assertEqual(table[backward], 42)

    def test_hash_caching(self):
        f = self.thetype('abcdcda')
        first = hash(f)
        self.assertEqual(first, hash(f))

    def test_hash_effectiveness(self):
        n = 13
        hashvalues = set()
        elemmasks = [(i + 1, 1 << i) for i in range(n)]
        for bits in xrange(2 ** n):
            members = [elem for elem, mask in elemmasks if mask & bits]
            hashvalues.add(hash(frozenset(members)))
        # All 2**n distinct subsets must produce distinct hash values.
        self.assertEqual(len(hashvalues), 2 ** n)
664
class FrozenSetSubclass(frozenset):
    """Trivial frozenset subclass; TestFrozenSetSubclass runs the frozenset
    suite on it."""
667
class TestFrozenSetSubclass(TestFrozenSet):
    """Re-run the frozenset tests against FrozenSetSubclass.

    Subclass instances are expected to be distinct objects, so the
    identity-based expectations of TestFrozenSet are inverted here.
    """

    thetype = FrozenSetSubclass

    def test_constructor_identity(self):
        first = self.thetype(range(3))
        second = self.thetype(first)
        self.assertNotEqual(id(first), id(second))

    def test_copy(self):
        clone = self.s.copy()
        self.assertNotEqual(id(self.s), id(clone))

    def test_nested_empty_constructor(self):
        outer = self.thetype()
        inner = self.thetype(outer)
        self.assertEqual(outer, inner)

    def test_singleton_empty_frozenset(self):
        Frozenset = self.thetype
        f = frozenset()
        F = Frozenset()
        efs = [Frozenset(), Frozenset([]), Frozenset(()), Frozenset(''),
               Frozenset(), Frozenset([]), Frozenset(()), Frozenset(''),
               Frozenset(xrange(0)), Frozenset(Frozenset()),
               Frozenset(frozenset()), f, F, Frozenset(f), Frozenset(F)]
        # Unlike plain frozenset, no two entries may share an identity.
        ids = [id(obj) for obj in efs]
        self.assertEqual(len(set(ids)), len(efs))
695
# Tests taken from test_sets.py =============================================

# Shared empty operand used throughout the TestBasicOps family of tests.
empty_set = set(())

#==============================================================================
701
class TestBasicOps(unittest.TestCase):
    """Basic-operation tests shared by the TestBasicOps* fixture classes.

    Each subclass's ``setUp`` provides ``self.values`` (source elements),
    ``self.set`` (the set under test), ``self.dup`` (an equal, separately
    built set), ``self.length`` and ``self.repr`` (the expected repr string,
    or None to skip the exact-repr check).
    """

    def test_repr(self):
        expected = self.repr
        if expected is None:
            return
        self.assertEqual(repr(self.set), expected)

    def check_repr_against_values(self):
        # Helper, not a test: checks that a brace-delimited repr lists exactly
        # the reprs of self.values, in any order.
        # NOTE(review): assumes the "{a, b}" repr form, while the fixtures in
        # this file expect "set([...])" -- confirm which form callers need.
        text = repr(self.set)
        self.assertTrue(text.startswith('{'))
        self.assertTrue(text.endswith('}'))

        got = sorted(text[1:-1].split(', '))
        want = sorted(repr(value) for value in self.values)
        self.assertEqual(got, want)

    def test_print(self):
        handle = open(test_support.TESTFN, "wb")
        try:
            print >> handle, self.set,
            handle.close()
            handle = open(test_support.TESTFN, "rb")
            # Printed output must match repr() exactly.
            self.assertEqual(handle.read(), repr(self.set))
        finally:
            handle.close()
            test_support.unlink(test_support.TESTFN)

    def test_length(self):
        actual_length = len(self.set)
        self.assertEqual(actual_length, self.length)

    def test_self_equality(self):
        self.assertEqual(self.set, self.set)

    def test_equivalent_equality(self):
        self.assertEqual(self.set, self.dup)

    def test_copy(self):
        duplicate = self.set.copy()
        self.assertEqual(duplicate, self.dup)

    def test_self_union(self):
        self.assertEqual(self.set | self.set, self.dup)

    def test_empty_union(self):
        self.assertEqual(self.set | empty_set, self.dup)

    def test_union_empty(self):
        self.assertEqual(empty_set | self.set, self.dup)

    def test_self_intersection(self):
        self.assertEqual(self.set & self.set, self.dup)

    def test_empty_intersection(self):
        self.assertEqual(self.set & empty_set, empty_set)

    def test_intersection_empty(self):
        self.assertEqual(empty_set & self.set, empty_set)

    def test_self_isdisjoint(self):
        # A set is disjoint from itself exactly when it is empty.
        self.assertEqual(self.set.isdisjoint(self.set), not self.set)

    def test_empty_isdisjoint(self):
        self.assertEqual(self.set.isdisjoint(empty_set), True)

    def test_isdisjoint_empty(self):
        self.assertEqual(empty_set.isdisjoint(self.set), True)

    def test_self_symmetric_difference(self):
        self.assertEqual(self.set ^ self.set, empty_set)

    def test_empty_symmetric_difference(self):
        self.assertEqual(self.set ^ empty_set, self.set)

    def test_self_difference(self):
        self.assertEqual(self.set - self.set, empty_set)

    def test_empty_difference(self):
        self.assertEqual(self.set - empty_set, self.dup)

    def test_empty_difference_rev(self):
        self.assertEqual(empty_set - self.set, empty_set)

    def test_iteration(self):
        for member in self.set:
            self.assertIn(member, self.values)
        it = iter(self.set)
        # note: __length_hint__ is an internal undocumented API,
        # don't rely on it in your own programs
        self.assertEqual(it.__length_hint__(), len(self.set))

    def test_pickling(self):
        # 'restored' rather than 'copy' so the copy module is not shadowed.
        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            restored = pickle.loads(pickle.dumps(self.set, protocol))
            self.assertEqual(self.set, restored,
                             "%s != %s" % (self.set, restored))
812
813#------------------------------------------------------------------------------
814
class TestBasicOpsEmpty(TestBasicOps):
    """Run the shared basic-operation checks against an empty set."""
    def setUp(self):
        self.values = []
        self.case   = "empty set"
        self.repr   = "set([])"
        self.length = len(self.values)
        self.set, self.dup = set(self.values), set(self.values)
823
824#------------------------------------------------------------------------------
825
class TestBasicOpsSingleton(TestBasicOps):
    """Run the shared basic-operation checks against the set {3}."""
    def setUp(self):
        self.values = [3]
        self.case   = "unit set (number)"
        self.repr   = "set([3])"
        self.length = len(self.values)
        self.set, self.dup = set(self.values), set(self.values)

    def test_in(self):
        # The single element is a member...
        self.assertIn(3, self.set)

    def test_not_in(self):
        # ...and a different number is not.
        self.assertNotIn(2, self.set)
840
841#------------------------------------------------------------------------------
842
class TestBasicOpsTuple(TestBasicOps):
    """Run the shared basic-operation checks against a set holding one tuple."""
    def setUp(self):
        self.values = [(0, "zero")]
        self.case   = "unit set (tuple)"
        self.repr   = "set([(0, 'zero')])"
        self.length = len(self.values)
        self.set, self.dup = set(self.values), set(self.values)

    def test_in(self):
        # An equal tuple is found by value.
        self.assertIn((0, "zero"), self.set)

    def test_not_in(self):
        self.assertNotIn(9, self.set)
857
858#------------------------------------------------------------------------------
859
class TestBasicOpsTriple(TestBasicOps):
    """Run the shared checks on three elements of unrelated types."""
    def setUp(self):
        self.values = [0, "zero", operator.add]
        self.case   = "triple set"
        # No fixed repr: element order for mixed types is not pinned here.
        self.repr   = None
        self.length = len(self.values)
        self.set, self.dup = set(self.values), set(self.values)
868
869#------------------------------------------------------------------------------
870
class TestBasicOpsString(TestBasicOps):
    """Run the shared checks on a set of three one-character strings."""
    def setUp(self):
        self.values = ["a", "b", "c"]
        self.case   = "string set"
        self.length = len(self.values)
        self.set, self.dup = set(self.values), set(self.values)

    def test_repr(self):
        # Use the shared helper instead of a fixed repr string.
        self.check_repr_against_values()
881
882#------------------------------------------------------------------------------
883
class TestBasicOpsUnicode(TestBasicOps):
    """Run the shared checks on a set of three unicode strings."""
    def setUp(self):
        self.values = [u"a", u"b", u"c"]
        self.case   = "unicode set"
        self.length = len(self.values)
        self.set, self.dup = set(self.values), set(self.values)

    def test_repr(self):
        # Use the shared helper instead of a fixed repr string.
        self.check_repr_against_values()
894
895#------------------------------------------------------------------------------
896
class TestBasicOpsMixedStringUnicode(TestBasicOps):
    """Run the shared checks on a set built from equal str and unicode values."""
    def setUp(self):
        self.case   = "string and bytes set"
        self.values = ["a", "b", u"a", u"b"]
        self.set    = set(self.values)
        self.dup    = set(self.values)
        # "a" == u"a" (and likewise for "b"), and equal ASCII str/unicode
        # values hash alike, so the four values collapse to two elements.
        self.length = 2

    def test_repr(self):
        # Comparing str with unicode may emit warnings; keep them contained.
        with test_support.check_warnings():
            self.check_repr_against_values()
908
909#==============================================================================
910
def baditer():
    """Generator that raises TypeError as soon as it is first advanced."""
    for item in ():
        yield item  # never runs; its presence makes this a generator function
    raise TypeError
914
def gooditer():
    """Generator that yields the single value True."""
    for value in (True,):
        yield value
917
918class TestExceptionPropagation(unittest.TestCase):
919    """SF 628246:  Set constructor should not trap iterator TypeErrors"""
920
921    def test_instanceWithException(self):
922        self.assertRaises(TypeError, set, baditer())
923
924    def test_instancesWithoutException(self):
925        # All of these iterables should load without exception.
926        set([1,2,3])
927        set((1,2,3))
928        set({'one':1, 'two':2, 'three':3})
929        set(xrange(3))
930        set('abc')
931        set(gooditer())
932
933    def test_changingSizeWhileIterating(self):
934        s = set([1,2,3])
935        try:
936            for i in s:
937                s.update([4])
938        except RuntimeError:
939            pass
940        else:
941            self.fail("no exception when changing size during iteration")
942
943#==============================================================================
944
class TestSetOfSets(unittest.TestCase):
    """A set can hold frozensets and add, remove and discard them."""
    def test_constructor(self):
        inner = frozenset([1])
        outer = set([inner])
        self.assertEqual(type(outer.pop()), frozenset)
        # Rebuild the set of sets with .add, then take it apart with .remove.
        outer.add(inner)
        outer.remove(inner)
        self.assertEqual(outer, set())
        # discard() of a missing element must not raise KeyError.
        outer.discard(inner)
955
956#==============================================================================
957
958class TestBinaryOps(unittest.TestCase):
959    def setUp(self):
960        self.set = set((2, 4, 6))
961
962    def test_eq(self):              # SF bug 643115
963        self.assertEqual(self.set, set({2:1,4:3,6:5}))
964
965    def test_union_subset(self):
966        result = self.set | set([2])
967        self.assertEqual(result, set((2, 4, 6)))
968
969    def test_union_superset(self):
970        result = self.set | set([2, 4, 6, 8])
971        self.assertEqual(result, set([2, 4, 6, 8]))
972
973    def test_union_overlap(self):
974        result = self.set | set([3, 4, 5])
975        self.assertEqual(result, set([2, 3, 4, 5, 6]))
976
977    def test_union_non_overlap(self):
978        result = self.set | set([8])
979        self.assertEqual(result, set([2, 4, 6, 8]))
980
981    def test_intersection_subset(self):
982        result = self.set & set((2, 4))
983        self.assertEqual(result, set((2, 4)))
984
985    def test_intersection_superset(self):
986        result = self.set & set([2, 4, 6, 8])
987        self.assertEqual(result, set([2, 4, 6]))
988
989    def test_intersection_overlap(self):
990        result = self.set & set([3, 4, 5])
991        self.assertEqual(result, set([4]))
992
993    def test_intersection_non_overlap(self):
994        result = self.set & set([8])
995        self.assertEqual(result, empty_set)
996
997    def test_isdisjoint_subset(self):
998        result = self.set.isdisjoint(set((2, 4)))
999        self.assertEqual(result, False)
1000
1001    def test_isdisjoint_superset(self):
1002        result = self.set.isdisjoint(set([2, 4, 6, 8]))
1003        self.assertEqual(result, False)
1004
1005    def test_isdisjoint_overlap(self):
1006        result = self.set.isdisjoint(set([3, 4, 5]))
1007        self.assertEqual(result, False)
1008
1009    def test_isdisjoint_non_overlap(self):
1010        result = self.set.isdisjoint(set([8]))
1011        self.assertEqual(result, True)
1012
1013    def test_sym_difference_subset(self):
1014        result = self.set ^ set((2, 4))
1015        self.assertEqual(result, set([6]))
1016
1017    def test_sym_difference_superset(self):
1018        result = self.set ^ set((2, 4, 6, 8))
1019        self.assertEqual(result, set([8]))
1020
1021    def test_sym_difference_overlap(self):
1022        result = self.set ^ set((3, 4, 5))
1023        self.assertEqual(result, set([2, 3, 5, 6]))
1024
1025    def test_sym_difference_non_overlap(self):
1026        result = self.set ^ set([8])
1027        self.assertEqual(result, set([2, 4, 6, 8]))
1028
1029    def test_cmp(self):
1030        a, b = set('a'), set('b')
1031        self.assertRaises(TypeError, cmp, a, b)
1032
1033        # You can view this as a buglet:  cmp(a, a) does not raise TypeError,
1034        # because __eq__ is tried before __cmp__, and a.__eq__(a) returns True,
1035        # which Python thinks is good enough to synthesize a cmp() result
1036        # without calling __cmp__.
1037        self.assertEqual(cmp(a, a), 0)
1038
1039
1040#==============================================================================
1041
1042class TestUpdateOps(unittest.TestCase):
1043    def setUp(self):
1044        self.set = set((2, 4, 6))
1045
1046    def test_union_subset(self):
1047        self.set |= set([2])
1048        self.assertEqual(self.set, set((2, 4, 6)))
1049
1050    def test_union_superset(self):
1051        self.set |= set([2, 4, 6, 8])
1052        self.assertEqual(self.set, set([2, 4, 6, 8]))
1053
1054    def test_union_overlap(self):
1055        self.set |= set([3, 4, 5])
1056        self.assertEqual(self.set, set([2, 3, 4, 5, 6]))
1057
1058    def test_union_non_overlap(self):
1059        self.set |= set([8])
1060        self.assertEqual(self.set, set([2, 4, 6, 8]))
1061
1062    def test_union_method_call(self):
1063        self.set.update(set([3, 4, 5]))
1064        self.assertEqual(self.set, set([2, 3, 4, 5, 6]))
1065
1066    def test_intersection_subset(self):
1067        self.set &= set((2, 4))
1068        self.assertEqual(self.set, set((2, 4)))
1069
1070    def test_intersection_superset(self):
1071        self.set &= set([2, 4, 6, 8])
1072        self.assertEqual(self.set, set([2, 4, 6]))
1073
1074    def test_intersection_overlap(self):
1075        self.set &= set([3, 4, 5])
1076        self.assertEqual(self.set, set([4]))
1077
1078    def test_intersection_non_overlap(self):
1079        self.set &= set([8])
1080        self.assertEqual(self.set, empty_set)
1081
1082    def test_intersection_method_call(self):
1083        self.set.intersection_update(set([3, 4, 5]))
1084        self.assertEqual(self.set, set([4]))
1085
1086    def test_sym_difference_subset(self):
1087        self.set ^= set((2, 4))
1088        self.assertEqual(self.set, set([6]))
1089
1090    def test_sym_difference_superset(self):
1091        self.set ^= set((2, 4, 6, 8))
1092        self.assertEqual(self.set, set([8]))
1093
1094    def test_sym_difference_overlap(self):
1095        self.set ^= set((3, 4, 5))
1096        self.assertEqual(self.set, set([2, 3, 5, 6]))
1097
1098    def test_sym_difference_non_overlap(self):
1099        self.set ^= set([8])
1100        self.assertEqual(self.set, set([2, 4, 6, 8]))
1101
1102    def test_sym_difference_method_call(self):
1103        self.set.symmetric_difference_update(set([3, 4, 5]))
1104        self.assertEqual(self.set, set([2, 3, 5, 6]))
1105
1106    def test_difference_subset(self):
1107        self.set -= set((2, 4))
1108        self.assertEqual(self.set, set([6]))
1109
1110    def test_difference_superset(self):
1111        self.set -= set((2, 4, 6, 8))
1112        self.assertEqual(self.set, set([]))
1113
1114    def test_difference_overlap(self):
1115        self.set -= set((3, 4, 5))
1116        self.assertEqual(self.set, set([2, 6]))
1117
1118    def test_difference_non_overlap(self):
1119        self.set -= set([8])
1120        self.assertEqual(self.set, set([2, 4, 6]))
1121
1122    def test_difference_method_call(self):
1123        self.set.difference_update(set([3, 4, 5]))
1124        self.assertEqual(self.set, set([2, 6]))
1125
1126#==============================================================================
1127
class TestMutate(unittest.TestCase):
    """Mutating set methods; every test starts from set("abc")."""

    def setUp(self):
        self.values = ["a", "b", "c"]
        self.set = set(self.values)

    def test_add_present(self):
        # Re-adding an element already present leaves the set unchanged.
        self.set.add("c")
        self.assertEqual(self.set, set("abc"))

    def test_add_absent(self):
        self.set.add("d")
        self.assertEqual(self.set, set("abcd"))

    def test_add_until_full(self):
        # The length grows by exactly one per newly added element.
        tmp = set()
        for count, v in enumerate(self.values, 1):
            tmp.add(v)
            self.assertEqual(len(tmp), count)
        self.assertEqual(tmp, self.set)

    def test_remove_present(self):
        self.set.remove("b")
        self.assertEqual(self.set, set("ac"))

    def test_remove_absent(self):
        # remove() of a missing element must raise a LookupError.
        try:
            self.set.remove("d")
        except LookupError:
            return
        self.fail("Removing missing element should have raised LookupError")

    def test_remove_until_empty(self):
        start = len(self.set)
        for removed, v in enumerate(self.values, 1):
            self.set.remove(v)
            self.assertEqual(len(self.set), start - removed)

    def test_discard_present(self):
        self.set.discard("c")
        self.assertEqual(self.set, set("ab"))

    def test_discard_absent(self):
        # discard() of a missing element is silently ignored.
        self.set.discard("d")
        self.assertEqual(self.set, set("abc"))

    def test_clear(self):
        self.set.clear()
        self.assertEqual(len(self.set), 0)

    def test_pop(self):
        # Popping until empty hands back every original element exactly once.
        popped = set()
        while self.set:
            popped.add(self.set.pop())
        self.assertEqual(len(popped), len(self.values))
        for v in self.values:
            self.assertIn(v, popped)

    def test_update_empty_tuple(self):
        self.set.update(())
        self.assertEqual(self.set, set(self.values))

    def test_update_unit_tuple_overlap(self):
        self.set.update(("a",))
        self.assertEqual(self.set, set(self.values))

    def test_update_unit_tuple_non_overlap(self):
        self.set.update(("a", "z"))
        self.assertEqual(self.set, set(self.values + ["z"]))
1199
1200#==============================================================================
1201
1202class TestSubsets(unittest.TestCase):
1203
1204    case2method = {"<=": "issubset",
1205                   ">=": "issuperset",
1206                  }
1207
1208    reverse = {"==": "==",
1209               "!=": "!=",
1210               "<":  ">",
1211               ">":  "<",
1212               "<=": ">=",
1213               ">=": "<=",
1214              }
1215
1216    def test_issubset(self):
1217        x = self.left
1218        y = self.right
1219        for case in "!=", "==", "<", "<=", ">", ">=":
1220            expected = case in self.cases
1221            # Test the binary infix spelling.
1222            result = eval("x" + case + "y", locals())
1223            self.assertEqual(result, expected)
1224            # Test the "friendly" method-name spelling, if one exists.
1225            if case in TestSubsets.case2method:
1226                method = getattr(x, TestSubsets.case2method[case])
1227                result = method(y)
1228                self.assertEqual(result, expected)
1229
1230            # Now do the same for the operands reversed.
1231            rcase = TestSubsets.reverse[case]
1232            result = eval("y" + rcase + "x", locals())
1233            self.assertEqual(result, expected)
1234            if rcase in TestSubsets.case2method:
1235                method = getattr(y, TestSubsets.case2method[rcase])
1236                result = method(x)
1237                self.assertEqual(result, expected)
1238#------------------------------------------------------------------------------
1239
class TestSubsetEqualEmpty(TestSubsets):
    """Two empty sets: equal, and each a non-proper subset of the other."""
    name  = "both empty"
    left  = set()
    right = set()
    cases = ("==", "<=", ">=")
1245
1246#------------------------------------------------------------------------------
1247
class TestSubsetEqualNonEmpty(TestSubsets):
    """Two equal non-empty sets."""
    name  = "equal pair"
    left  = {1, 2}
    right = {1, 2}
    cases = ("==", "<=", ">=")
1253
1254#------------------------------------------------------------------------------
1255
class TestSubsetEmptyNonEmpty(TestSubsets):
    """The empty set is a proper subset of a non-empty one."""
    name  = "one empty, one non-empty"
    left  = set()
    right = {1, 2}
    cases = ("!=", "<", "<=")
1261
1262#------------------------------------------------------------------------------
1263
class TestSubsetPartial(TestSubsets):
    """A non-empty proper subset."""
    name  = "one a non-empty proper subset of other"
    left  = {1}
    right = {1, 2}
    cases = ("!=", "<", "<=")
1269
1270#------------------------------------------------------------------------------
1271
class TestSubsetNonOverlap(TestSubsets):
    """Two disjoint non-empty sets: only "!=" holds."""
    left  = set([1])
    right = set([2])
    name  = "neither empty, neither contains"
    # A one-element tuple, not the bare string "!=": ``case in self.cases``
    # must test tuple membership, not substring containment.
    cases = ("!=",)
1277
1278#==============================================================================
1279
1280class TestOnlySetsInBinaryOps(unittest.TestCase):
1281
1282    def test_eq_ne(self):
1283        # Unlike the others, this is testing that == and != *are* allowed.
1284        self.assertEqual(self.other == self.set, False)
1285        self.assertEqual(self.set == self.other, False)
1286        self.assertEqual(self.other != self.set, True)
1287        self.assertEqual(self.set != self.other, True)
1288
1289    def test_update_operator(self):
1290        try:
1291            self.set |= self.other
1292        except TypeError:
1293            pass
1294        else:
1295            self.fail("expected TypeError")
1296
1297    def test_update(self):
1298        if self.otherIsIterable:
1299            self.set.update(self.other)
1300        else:
1301            self.assertRaises(TypeError, self.set.update, self.other)
1302
1303    def test_union(self):
1304        self.assertRaises(TypeError, lambda: self.set | self.other)
1305        self.assertRaises(TypeError, lambda: self.other | self.set)
1306        if self.otherIsIterable:
1307            self.set.union(self.other)
1308        else:
1309            self.assertRaises(TypeError, self.set.union, self.other)
1310
1311    def test_intersection_update_operator(self):
1312        try:
1313            self.set &= self.other
1314        except TypeError:
1315            pass
1316        else:
1317            self.fail("expected TypeError")
1318
1319    def test_intersection_update(self):
1320        if self.otherIsIterable:
1321            self.set.intersection_update(self.other)
1322        else:
1323            self.assertRaises(TypeError,
1324                              self.set.intersection_update,
1325                              self.other)
1326
1327    def test_intersection(self):
1328        self.assertRaises(TypeError, lambda: self.set & self.other)
1329        self.assertRaises(TypeError, lambda: self.other & self.set)
1330        if self.otherIsIterable:
1331            self.set.intersection(self.other)
1332        else:
1333            self.assertRaises(TypeError, self.set.intersection, self.other)
1334
1335    def test_sym_difference_update_operator(self):
1336        try:
1337            self.set ^= self.other
1338        except TypeError:
1339            pass
1340        else:
1341            self.fail("expected TypeError")
1342
1343    def test_sym_difference_update(self):
1344        if self.otherIsIterable:
1345            self.set.symmetric_difference_update(self.other)
1346        else:
1347            self.assertRaises(TypeError,
1348                              self.set.symmetric_difference_update,
1349                              self.other)
1350
1351    def test_sym_difference(self):
1352        self.assertRaises(TypeError, lambda: self.set ^ self.other)
1353        self.assertRaises(TypeError, lambda: self.other ^ self.set)
1354        if self.otherIsIterable:
1355            self.set.symmetric_difference(self.other)
1356        else:
1357            self.assertRaises(TypeError, self.set.symmetric_difference, self.other)
1358
1359    def test_difference_update_operator(self):
1360        try:
1361            self.set -= self.other
1362        except TypeError:
1363            pass
1364        else:
1365            self.fail("expected TypeError")
1366
1367    def test_difference_update(self):
1368        if self.otherIsIterable:
1369            self.set.difference_update(self.other)
1370        else:
1371            self.assertRaises(TypeError,
1372                              self.set.difference_update,
1373                              self.other)
1374
1375    def test_difference(self):
1376        self.assertRaises(TypeError, lambda: self.set - self.other)
1377        self.assertRaises(TypeError, lambda: self.other - self.set)
1378        if self.otherIsIterable:
1379            self.set.difference(self.other)
1380        else:
1381            self.assertRaises(TypeError, self.set.difference, self.other)
1382
1383#------------------------------------------------------------------------------
1384
class TestOnlySetsNumeric(TestOnlySetsInBinaryOps):
    """Other operand is an int, which is not iterable."""
    def setUp(self):
        self.otherIsIterable = False
        self.other = 19
        self.set = {1, 2, 3}
1390
1391#------------------------------------------------------------------------------
1392
class TestOnlySetsDict(TestOnlySetsInBinaryOps):
    """Other operand is a dict: iterable, but not a set."""
    def setUp(self):
        self.otherIsIterable = True
        self.other = {1: 2, 3: 4}
        self.set = {1, 2, 3}
1398
1399#------------------------------------------------------------------------------
1400
class TestOnlySetsTuple(TestOnlySetsInBinaryOps):
    """Other operand is a tuple: iterable, but not a set."""
    def setUp(self):
        self.otherIsIterable = True
        self.other = (2, 4, 6)
        self.set = {1, 2, 3}
1406
1407#------------------------------------------------------------------------------
1408
class TestOnlySetsString(TestOnlySetsInBinaryOps):
    """Other operand is a string: iterable, but not a set."""
    def setUp(self):
        self.otherIsIterable = True
        self.other = 'abc'
        self.set = {1, 2, 3}
1414
1415#------------------------------------------------------------------------------
1416
class TestOnlySetsGenerator(TestOnlySetsInBinaryOps):
    """Other operand is a one-shot generator of the even numbers 0..8."""
    def setUp(self):
        self.set = {1, 2, 3}
        self.other = (i for i in xrange(0, 10, 2))
        self.otherIsIterable = True
1425
1426#==============================================================================
1427
1428class TestCopying(unittest.TestCase):
1429
1430    def test_copy(self):
1431        dup = list(self.set.copy())
1432        self.assertEqual(len(dup), len(self.set))
1433        for el in self.set:
1434            self.assertIn(el, dup)
1435            pos = dup.index(el)
1436            self.assertIs(el, dup.pop(pos))
1437        self.assertFalse(dup)
1438
1439    def test_deep_copy(self):
1440        dup = copy.deepcopy(self.set)
1441        self.assertSetEqual(dup, self.set)
1442
1443#------------------------------------------------------------------------------
1444
class TestCopyingEmpty(TestCopying):
    """Copying an empty set."""
    def setUp(self):
        self.set = set()
1448
1449#------------------------------------------------------------------------------
1450
class TestCopyingSingleton(TestCopying):
    """Copying a set holding one string."""
    def setUp(self):
        self.set = {"hello"}
1454
1455#------------------------------------------------------------------------------
1456
class TestCopyingTriple(TestCopying):
    """Copying a set of three elements of different types."""
    def setUp(self):
        self.set = {"zero", 0, None}
1460
1461#------------------------------------------------------------------------------
1462
class TestCopyingTuple(TestCopying):
    """Copying a set holding one tuple."""
    def setUp(self):
        self.set = {(1, 2)}
1466
1467#------------------------------------------------------------------------------
1468
class TestCopyingNested(TestCopying):
    """Copying a set holding a tuple of tuples."""
    def setUp(self):
        self.set = {((1, 2), (3, 4))}
1472
1473#==============================================================================
1474
class TestIdentities(unittest.TestCase):
    """Algebraic identities between |, &, ^, - and the subset relations."""

    def setUp(self):
        self.a = set('abracadabra')
        self.b = set('alacazam')

    def test_binopsVsSubsets(self):
        a, b = self.a, self.b
        relations = [a - b < a, b - a < b, a & b < a, a & b < b,
                     a | b > a, a | b > b, a ^ b < a | b]
        for holds in relations:
            self.assertTrue(holds)

    def test_commutativity(self):
        a, b = self.a, self.b
        for op in (operator.and_, operator.or_, operator.xor):
            self.assertEqual(op(a, b), op(b, a))
        if a != b:
            self.assertNotEqual(a-b, b-a)

    def test_summations(self):
        # check that sums of parts equal the whole
        a, b = self.a, self.b
        pairs = [((a-b)|(a&b)|(b-a), a|b),
                 ((a&b)|(a^b), a|b),
                 (a|(b-a), a|b),
                 ((a-b)|b, a|b),
                 ((a-b)|(a&b), a),
                 ((b-a)|(a&b), b),
                 ((a-b)|(b-a), a^b)]
        for lhs, rhs in pairs:
            self.assertEqual(lhs, rhs)

    def test_exclusion(self):
        # check that inverse operations show non-overlap
        a, b, zero = self.a, self.b, set()
        for overlap in ((a-b)&b, (b-a)&a, (a&b)&(a^b)):
            self.assertEqual(overlap, zero)
1515
1516# Tests derived from test_itertools.py =======================================
1517
def R(seqn):
    """Plain generator that re-yields every item of seqn."""
    it = iter(seqn)
    while True:
        try:
            item = next(it)
        except StopIteration:
            return
        yield item
1522
class G:
    """Sequence that is iterable only via the old __getitem__ protocol."""
    def __init__(self, seqn):
        self.seqn = seqn
    def __getitem__(self, i):
        seq = self.seqn
        return seq[i]
1529
class I:
    """Iterator implementing the (Python 2) __iter__/next protocol by hand."""
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0
    def __iter__(self):
        return self
    def next(self):
        pos = self.i
        if pos < len(self.seqn):
            value = self.seqn[pos]
            self.i = pos + 1
            return value
        raise StopIteration
1542
class Ig:
    """Iterable whose __iter__ is a generator function over seqn."""
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0
    def __iter__(self):
        for item in self.seqn:
            yield item
1551
class X:
    """Has next() but neither __iter__ nor __getitem__, so it is not iterable."""
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0
    def next(self):
        pos = self.i
        if pos < len(self.seqn):
            value = self.seqn[pos]
            self.i = pos + 1
            return value
        raise StopIteration
1562
class N:
    """Claims to be an iterator (__iter__ returns self) but has no next()."""
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0
    def __iter__(self):
        return self
1570
class E:
    """Iterator whose next() always raises ZeroDivisionError."""
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0
    def __iter__(self):
        return self
    def next(self):
        # Deliberate failure; callers must let this propagate unchanged.
        return 3 // 0
1580
class S:
    """Iterator that is exhausted from the start, whatever data it is given."""
    def __init__(self, seqn):
        pass  # the data is deliberately ignored
    def __iter__(self):
        return self
    def next(self):
        raise StopIteration
1589
1590from itertools import chain, imap
def L(seqn):
    """Wrap seqn in several stacked iterator layers (G -> Ig -> R -> imap -> chain)."""
    layered = R(Ig(G(seqn)))
    identity = lambda x: x
    return chain(imap(identity, layered))
1594
class TestVariousIteratorArgs(unittest.TestCase):
    """Feed set operations each of the iterable flavours defined above.

    G, I, Ig, S, L and R are valid iterables and must give the same result
    as the underlying data. X and N are not valid iterables (TypeError).
    E raises ZeroDivisionError from next(), which must propagate.
    """

    def test_constructor(self):
        for cons in (set, frozenset):
            for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
                for g in (G, I, Ig, S, L, R):
                    self.assertSetEqual(cons(g(s)), set(g(s)))
                self.assertRaises(TypeError, cons, X(s))
                self.assertRaises(TypeError, cons, N(s))
                self.assertRaises(ZeroDivisionError, cons, E(s))

    def test_inline_methods(self):
        s = set('november')
        for data in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5), 'december'):
            for meth in (s.union, s.intersection, s.difference, s.symmetric_difference, s.isdisjoint):
                # The expected result depends only on meth and data, not on
                # the wrapper g, so compute it once per method.
                expected = meth(data)
                for g in (G, I, Ig, L, R):
                    actual = meth(g(data))
                    if isinstance(expected, bool):
                        self.assertEqual(actual, expected)
                    else:
                        self.assertSetEqual(actual, expected)
                # Wrap the argument data (as the sibling tests do), not the
                # receiver set s.
                self.assertRaises(TypeError, meth, X(data))
                self.assertRaises(TypeError, meth, N(data))
                self.assertRaises(ZeroDivisionError, meth, E(data))

    def test_inplace_methods(self):
        for data in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5), 'december'):
            for methname in ('update', 'intersection_update',
                             'difference_update', 'symmetric_difference_update'):
                for g in (G, I, Ig, S, L, R):
                    # Applying the method to a materialized list and to the
                    # wrapped iterable must give the same result.
                    s = set('january')
                    t = s.copy()
                    getattr(s, methname)(list(g(data)))
                    getattr(t, methname)(g(data))
                    self.assertSetEqual(s, t)

                self.assertRaises(TypeError, getattr(set('january'), methname), X(data))
                self.assertRaises(TypeError, getattr(set('january'), methname), N(data))
                self.assertRaises(ZeroDivisionError, getattr(set('january'), methname), E(data))
1635
class bad_eq:
    """All instances collide; when module global be_bad is set, __eq__ clears set2 and raises."""
    def __hash__(self):
        return 0
    def __eq__(self, other):
        if not be_bad:
            return self is other
        set2.clear()
        raise ZeroDivisionError
1644
class bad_dict_clear:
    """All instances collide; when module global be_bad is set, __eq__ clears dict2."""
    def __hash__(self):
        return 0
    def __eq__(self, other):
        if be_bad:
            dict2.clear()
        return self is other
1652
class TestWeirdBugs(unittest.TestCase):
    # Regression tests for past crashes: set internals mutated from inside
    # __eq__ callbacks, or mutated while an iterator over the set is alive.

    def test_8420_set_merge(self):
        # This used to segfault
        # bad_eq / bad_dict_clear read be_bad, set2 and dict2 as module
        # globals, hence the ``global`` statement.
        global be_bad, set2, dict2
        be_bad = False
        set1 = {bad_eq()}
        # 75 distinct bad_eq objects, all with hash 0 (identity equality
        # while be_bad is False).
        set2 = {bad_eq() for i in range(75)}
        # From here, any bad_eq.__eq__ call clears set2 and raises.
        be_bad = True
        self.assertRaises(ZeroDivisionError, set1.update, set2)

        be_bad = False
        set1 = {bad_dict_clear()}
        dict2 = {bad_dict_clear(): None}
        # From here, bad_dict_clear.__eq__ empties dict2 when called. There
        # is no assertion: the test passes if this call returns.
        be_bad = True
        set1.symmetric_difference_update(dict2)

    def test_iter_and_mutate(self):
        # Issue #24581
        # Take an iterator, then clear and refill the set underneath it
        # before draining the iterator.
        s = set(range(100))
        s.clear()
        s.update(range(100))
        si = iter(s)
        s.clear()
        # NOTE(review): ``a`` is never read. It looks like a deliberate
        # allocation between clear() and update() (perhaps so freed memory
        # is reused); keep it unless issue #24581 shows it is unneeded.
        a = list(range(100))
        s.update(range(100))
        list(si)
1679
# Application tests (based on David Eppstein's graph recipes) ====================================
1681
def powerset(U):
    """Generate all subsets of a set or sequence U, as frozensets.

    The first element x is taken from U. Then, for every subset S of the
    remaining elements, S and S | {x} are yielded. An exhausted U yields
    only the empty frozenset. The generator consumes U's iterator.
    """
    U = iter(U)
    try:
        # Only this call is expected to raise StopIteration; keep the try
        # narrow so StopIteration is never caught (or leaked) elsewhere.
        x = frozenset([next(U)])
    except StopIteration:
        yield frozenset()
        return
    for S in powerset(U):
        yield S
        yield S | x
1692
def cube(n):
    """Return the n-dimensional hypercube graph as an adjacency dict.

    Vertices are the subsets of range(n) (frozensets). Each vertex's
    neighbours are its symmetric differences with every singleton {i},
    i.e. the subsets that differ from it in exactly one element.
    """
    singletons = [frozenset([i]) for i in range(n)]
    return {vertex: frozenset(vertex ^ single for single in singletons)
            for vertex in powerset(range(n))}
1698
def linegraph(G):
    """Return the line graph of adjacency dict G.

    Each vertex of the result is an edge {x, y} of G (a frozenset). Two such
    vertices are adjacent iff the corresponding edges share an endpoint.
    """
    def _other_edges(a, b):
        # Edges {a, z} at endpoint a, excluding the edge back to b.
        return [frozenset([a, z]) for z in G[a] if z != b]

    return dict((frozenset([x, y]),
                 frozenset(_other_edges(x, y) + _other_edges(y, x)))
                for x in G
                for y in G[x])
1710
def faces(G):
    """Return the set of faces of G, each face being a frozenset of vertices.

    G maps each vertex to its neighbours.  Only cycles of length 3, 4 and 5
    are found: a walk a-b-c(-d(-e)) that closes back on ``a`` is recorded
    at the shortest length at which it closes.
    """
    found = set()
    for a, nbrs_a in G.items():
        for b in nbrs_a:
            for c in G[b]:
                if c == a:
                    continue
                if a in G[c]:
                    found.add(frozenset([a, b, c]))        # triangle
                    continue
                for d in G[c]:
                    if d == b:
                        continue
                    if a in G[d]:
                        found.add(frozenset([a, b, c, d]))     # square
                        continue
                    for e in G[d]:
                        if e == c or e == b:
                            continue
                        if a in G[e]:
                            found.add(frozenset([a, b, c, d, e]))  # pentagon
    return found
1735
1736
class TestGraphs(unittest.TestCase):
    """Application tests: build small graphs out of frozensets with cube(),
    linegraph() and faces(), and check known vertex/edge/face counts."""

    def test_cube(self):
        # cube(3) maps each vertex to the frozenset of its neighbours.
        g = cube(3)
        vertices1 = set(g)
        self.assertEqual(len(vertices1), 8)     # eight vertices
        for neighbours in g.values():
            self.assertEqual(len(neighbours), 3)    # each vertex has three neighbours
        vertices2 = set(v for neighbours in g.values() for v in neighbours)
        self.assertEqual(vertices1, vertices2)  # every neighbour is itself a vertex

        cubefaces = faces(g)
        self.assertEqual(len(cubefaces), 6)     # six faces
        for face in cubefaces:
            self.assertEqual(len(face), 4)      # each face is a square

    def test_cuboctahedron(self):

        # http://en.wikipedia.org/wiki/Cuboctahedron
        # 8 triangular faces and 6 square faces
        # 12 identical vertices each connecting a triangle and square

        g = cube(3)
        cuboctahedron = linegraph(g)            # cube edge --> {cube edges sharing an endpoint}
        self.assertEqual(len(cuboctahedron), 12)    # twelve vertices

        vertices = set(cuboctahedron)
        for edges in cuboctahedron.values():
            self.assertEqual(len(edges), 4)     # each vertex connects to four other vertices
        othervertices = set(edge for edges in cuboctahedron.values() for edge in edges)
        self.assertEqual(vertices, othervertices)   # edge vertices in original set

        cubofaces = faces(cuboctahedron)
        facesizes = collections.defaultdict(int)
        for face in cubofaces:
            facesizes[len(face)] += 1
        self.assertEqual(facesizes[3], 8)       # eight triangular faces
        self.assertEqual(facesizes[4], 6)       # six square faces

        for vertex in cuboctahedron:
            edge = vertex                       # Cuboctahedron vertices are edges in Cube
            self.assertEqual(len(edge), 2)      # Two cube vertices define an edge
            for cubevert in edge:
                self.assertIn(cubevert, g)
            # Being two cube vertices is not enough: they must also be
            # adjacent in the cube for the pair to actually be a cube edge.
            u, v = edge
            self.assertIn(v, g[u])
1782
1783
1784#==============================================================================
1785
def test_main(verbose=None):
    """Run every set/frozenset test class in this module.

    When *verbose* is true and the interpreter provides
    sys.gettotalrefcount() (presumably only debug builds), the whole suite
    is re-run five more times and the total refcount after each run is
    printed; a steadily growing sequence hints at a reference leak.
    """
    test_classes = (
        TestSet,
        TestSetSubclass,
        TestSetSubclassWithKeywordArgs,
        TestFrozenSet,
        TestFrozenSetSubclass,
        TestSetOfSets,
        TestExceptionPropagation,
        TestBasicOpsEmpty,
        TestBasicOpsSingleton,
        TestBasicOpsTuple,
        TestBasicOpsTriple,
        TestBinaryOps,
        TestUpdateOps,
        TestMutate,
        TestSubsetEqualEmpty,
        TestSubsetEqualNonEmpty,
        TestSubsetEmptyNonEmpty,
        TestSubsetPartial,
        TestSubsetNonOverlap,
        TestOnlySetsNumeric,
        TestOnlySetsDict,
        TestOnlySetsTuple,
        TestOnlySetsString,
        TestOnlySetsGenerator,
        TestCopyingEmpty,
        TestCopyingSingleton,
        TestCopyingTriple,
        TestCopyingTuple,
        TestCopyingNested,
        TestIdentities,
        TestVariousIteratorArgs,
        TestGraphs,
        TestWeirdBugs,
        )

    test_support.run_unittest(*test_classes)

    if not verbose or not hasattr(sys, "gettotalrefcount"):
        return

    # Reference-leak check.
    import gc
    # Pre-sized, presumably so that filling a slot swaps one reference for
    # another instead of growing the list between refcount readings.
    counts = [None] * 5
    for run in xrange(len(counts)):
        test_support.run_unittest(*test_classes)
        gc.collect()
        counts[run] = sys.gettotalrefcount()
    print(counts)
1834
if __name__ == "__main__":
    # Run directly: verbose mode, which on interpreters that provide
    # sys.gettotalrefcount() also performs test_main's refcount check.
    test_main(verbose=True)
1837