• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import unittest
2from weakref import WeakSet
3import string
4from collections import UserString as ustr
5from collections.abc import Set, MutableSet
6import gc
7import contextlib
8from test import support
9
10
class Foo:
    """Minimal user-defined class with no behavior of its own.

    Instances accept arbitrary attributes and can be stored in a WeakSet,
    which is all the tests need from them.
    """
13
class RefCycle:
    """Object whose only attribute points back at itself.

    The self-reference forms a reference cycle, so in CPython an instance is
    reclaimed by the cyclic garbage collector rather than by reference
    counting alone; the len() tests rely on this.
    """

    def __init__(self):
        # Attribute name is part of the contract: the object refers to itself.
        self.cycle = self
17
18
class TestWeakSet(unittest.TestCase):
    """Behavioral tests for weakref.WeakSet.

    setUp builds several WeakSets whose elements are ``ustr``
    (collections.UserString) objects.  Strong references to every element
    are kept in lists on ``self``.  Individual tests drop some of those
    references (``del``, ``list.pop``) and force a collection, to check that
    the WeakSet forgets elements once they are gone.
    """

    def setUp(self):
        # need to keep references to them: a WeakSet on its own does not
        # keep its elements alive (see test_len)
        self.items = [ustr(c) for c in ('a', 'b', 'c')]
        self.items2 = [ustr(c) for c in ('x', 'y', 'z')]
        self.ab_items = [ustr(c) for c in 'ab']
        self.abcde_items = [ustr(c) for c in 'abcde']
        self.def_items = [ustr(c) for c in 'def']
        self.ab_weakset = WeakSet(self.ab_items)
        self.abcde_weakset = WeakSet(self.abcde_items)
        self.def_weakset = WeakSet(self.def_items)
        self.letters = [ustr(c) for c in string.ascii_letters]
        self.s = WeakSet(self.items)
        # Plain dict with the same keys as self.s, used as a reference model.
        self.d = dict.fromkeys(self.items)
        # self.fs holds exactly one element, self.obj.  Tests delete
        # self.obj to watch it disappear from self.fs.
        self.obj = ustr('F')
        self.fs = WeakSet([self.obj])

    def test_methods(self):
        # Every public method of the builtin set must also exist on WeakSet.
        # 'test_c_api' is skipped explicitly; it is not API WeakSet mirrors.
        weaksetmethods = dir(WeakSet)
        for method in dir(set):
            if method == 'test_c_api' or method.startswith('_'):
                continue
            self.assertIn(method, weaksetmethods,
                          "WeakSet missing method " + method)

    def test_new_or_init(self):
        # WeakSet takes at most one iterable argument.
        self.assertRaises(TypeError, WeakSet, [], 2)

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))
        self.assertEqual(len(self.fs), 1)
        del self.obj
        support.gc_collect()  # For PyPy or other GCs.
        self.assertEqual(len(self.fs), 0)

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        # 1 is not weakref'able, but that TypeError is caught by __contains__
        self.assertNotIn(1, self.s)
        self.assertIn(self.obj, self.fs)
        del self.obj
        support.gc_collect()  # For PyPy or other GCs.
        self.assertNotIn(ustr('F'), self.fs)

    def test_union(self):
        u = self.s.union(self.items2)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.items2)
        # union() returns a new WeakSet and leaves the receiver unchanged.
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(u), WeakSet)
        # A list element cannot be stored in a WeakSet.
        self.assertRaises(TypeError, self.s.union, [[]])
        # union() accepts any iterable of elements.
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet(self.items + self.items2)
            c = C(self.items2)
            self.assertEqual(self.s.union(c), x)
            del c
        # The result holds its elements weakly too: dropping one of the
        # strong references shrinks it.
        self.assertEqual(len(u), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(u), len(self.items) + len(self.items2))

    def test_or(self):
        i = self.s.union(self.items2)
        self.assertEqual(self.s | set(self.items2), i)
        self.assertEqual(self.s | frozenset(self.items2), i)

    def test_intersection(self):
        s = WeakSet(self.letters)
        i = s.intersection(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.items2 and c in self.letters)
        self.assertEqual(s, WeakSet(self.letters))
        self.assertEqual(type(i), WeakSet)
        # i holds only 'x', 'y', 'z', none of which appear in self.items.
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet([])
            self.assertEqual(i.intersection(C(self.items)), x)
        self.assertEqual(len(i), len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items2))

    def test_isdisjoint(self):
        self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
        self.assertFalse(self.s.isdisjoint(WeakSet(self.letters)))

    def test_and(self):
        i = self.s.intersection(self.items2)
        self.assertEqual(self.s & set(self.items2), i)
        self.assertEqual(self.s & frozenset(self.items2), i)

    def test_difference(self):
        i = self.s.difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.difference, [[]])

    def test_sub(self):
        i = self.s.difference(self.items2)
        self.assertEqual(self.s - set(self.items2), i)
        self.assertEqual(self.s - frozenset(self.items2), i)

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        self.assertEqual(len(i), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items) + len(self.items2))

    def test_xor(self):
        i = self.s.symmetric_difference(self.items2)
        self.assertEqual(self.s ^ set(self.items2), i)
        self.assertEqual(self.s ^ frozenset(self.items2), i)

    def test_sub_and_super(self):
        # Operator forms, WeakSet against WeakSet.
        self.assertTrue(self.ab_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset >= self.ab_weakset)
        self.assertFalse(self.abcde_weakset <= self.def_weakset)
        self.assertFalse(self.abcde_weakset >= self.def_weakset)
        # Method forms on WeakSet, given a plain iterable of elements.
        self.assertTrue(self.ab_weakset.issubset(self.abcde_items))
        self.assertTrue(self.abcde_weakset.issuperset(self.ab_items))
        self.assertFalse(self.ab_weakset.issubset(self.def_items))
        self.assertFalse(self.def_weakset.issuperset(self.ab_items))

    def test_lt(self):
        self.assertTrue(self.ab_weakset < self.abcde_weakset)
        self.assertFalse(self.abcde_weakset < self.def_weakset)
        self.assertFalse(self.ab_weakset < self.ab_weakset)
        self.assertFalse(WeakSet() < WeakSet())

    def test_gt(self):
        self.assertTrue(self.abcde_weakset > self.ab_weakset)
        self.assertFalse(self.abcde_weakset > self.def_weakset)
        self.assertFalse(self.ab_weakset > self.ab_weakset)
        self.assertFalse(WeakSet() > WeakSet())

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check.
        # No assertions: the test passes if building and discarding these
        # cycles does not crash.
        s = WeakSet(Foo() for i in range(1000))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = WeakSet([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(WeakSet):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_init(self):
        # Calling __init__ again replaces the contents.
        s = WeakSet()
        s.__init__(self.items)
        self.assertEqual(s, self.s)
        s.__init__(self.items2)
        self.assertEqual(s, WeakSet(self.items2))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = WeakSet(self.items)
        t = WeakSet(s)
        self.assertNotEqual(id(s), id(t))

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, WeakSet([]))
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertNotEqual(id(self.s), id(dup))

    def test_add(self):
        x = ustr('Q')
        self.s.add(x)
        self.assertIn(x, self.s)
        # Adding an element that is already present is a no-op.
        dup = self.s.copy()
        self.s.add(x)
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])
        # A temporary Foo() has no other references, so it vanishes
        # from the set once collected.
        self.fs.add(Foo())
        support.gc_collect()  # For PyPy or other GCs.
        self.assertEqual(len(self.fs), 1)
        self.fs.add(self.obj)
        self.assertEqual(len(self.fs), 1)

    def test_remove(self):
        x = ustr('a')
        self.s.remove(x)
        self.assertNotIn(x, self.s)
        self.assertRaises(KeyError, self.s.remove, x)
        self.assertRaises(TypeError, self.s.remove, [])

    def test_discard(self):
        # Unlike remove(), discarding a missing element does not raise.
        a, q = ustr('a'), ustr('Q')
        self.s.discard(a)
        self.assertNotIn(a, self.s)
        self.s.discard(q)
        self.assertRaises(TypeError, self.s.discard, [])

    def test_pop(self):
        for i in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)
        self.assertRaises(TypeError, self.s.update, [[]])

    def test_update_set(self):
        self.s.update(set(self.items2))
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_ior(self):
        self.s |= set(self.items2)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.intersection_update, [[]])

    def test_iand(self):
        self.s &= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_isub(self):
        self.s -= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_ixor(self):
        self.s ^= set(self.items2)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        # In-place operators with the same object on both sides.
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, WeakSet())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, WeakSet())

    def test_eq(self):
        # issue 5964: a WeakSet only compares equal to another WeakSet
        self.assertTrue(self.s == self.s)
        self.assertTrue(self.s == WeakSet(self.items))
        self.assertFalse(self.s == set(self.items))
        self.assertFalse(self.s == list(self.items))
        self.assertFalse(self.s == tuple(self.items))
        self.assertFalse(self.s == WeakSet([Foo]))
        self.assertFalse(self.s == 1)

    def test_ne(self):
        self.assertTrue(self.s != set(self.items))
        s1 = WeakSet()
        s2 = WeakSet()
        self.assertFalse(s1 != s2)

    def test_weak_destroy_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        # Create new items to be sure no-one else holds a reference
        items = [ustr(c) for c in ('a', 'b', 'c')]
        s = WeakSet(items)
        it = iter(s)
        next(it)             # Trigger internal iteration
        # Destroy an item
        del items[-1]
        gc.collect()    # just in case
        # We have removed either the first consumed items, or another one
        self.assertIn(len(list(it)), [len(items), len(items) - 1])
        del it
        # The removal has been committed
        self.assertEqual(len(s), len(items))

    def test_weak_destroy_and_mutate_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        items = [ustr(c) for c in string.ascii_letters]
        s = WeakSet(items)

        @contextlib.contextmanager
        def testcontext():
            # Yields a fresh ustr equal to an element whose only strong
            # reference was just dropped, while an iterator over s is live.
            try:
                it = iter(s)
                # Start iterator
                yielded = ustr(str(next(it)))
                # Schedule an item for removal and recreate it
                u = ustr(str(items.pop()))
                if yielded == u:
                    # The iterator still has a reference to the removed item,
                    # advance it (issue #20006).
                    next(it)
                gc.collect()      # just in case
                yield u
            finally:
                it = None           # should commit all removals

        with testcontext() as u:
            self.assertNotIn(u, s)
        with testcontext() as u:
            self.assertRaises(KeyError, s.remove, u)
        self.assertNotIn(u, s)
        with testcontext() as u:
            s.add(u)
        self.assertIn(u, s)
        t = s.copy()
        with testcontext() as u:
            s.update(t)
        self.assertEqual(len(s), len(t))
        with testcontext() as u:
            s.clear()
        self.assertEqual(len(s), 0)

    def test_len_cycles(self):
        # Elements only reachable through cycles disappear after collection.
        N = 20
        items = [RefCycle() for i in range(N)]
        s = WeakSet(items)
        del items
        it = iter(s)
        try:
            next(it)
        except StopIteration:
            pass
        gc.collect()
        n1 = len(s)
        del it
        gc.collect()
        gc.collect()  # For PyPy or other GCs.
        n2 = len(s)
        # one item may be kept alive inside the iterator
        self.assertIn(n1, (0, 1))
        self.assertEqual(n2, 0)

    def test_len_race(self):
        # Extended sanity checks for len() in the face of cyclic collection
        self.addCleanup(gc.set_threshold, *gc.get_threshold())
        for th in range(1, 100):
            N = 20
            gc.collect(0)
            gc.set_threshold(th, th, th)
            items = [RefCycle() for i in range(N)]
            s = WeakSet(items)
            del items
            # All items will be collected at next garbage collection pass
            it = iter(s)
            try:
                next(it)
            except StopIteration:
                pass
            n1 = len(s)
            del it
            n2 = len(s)
            self.assertGreaterEqual(n1, 0)
            self.assertLessEqual(n1, N)
            self.assertGreaterEqual(n2, 0)
            self.assertLessEqual(n2, n1)

    def test_repr(self):
        # Use assertEqual rather than a bare assert, which ``python -O``
        # strips, silently turning this test into a no-op.
        self.assertEqual(repr(self.s), repr(self.s.data))

    def test_abc(self):
        self.assertIsInstance(self.s, Set)
        self.assertIsInstance(self.s, MutableSet)
449
450
if __name__ == '__main__':
    # Only run the test suite when this module is executed as a script,
    # not when it is imported.
    unittest.main()
453