• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Access WeakSet through the weakref module.
# This code is separated out because it is needed
# by abc.py to load everything else at startup.
4
5from _weakref import ref
6from types import GenericAlias
7
# Names exported by ``from ... import *``: just the WeakSet class.
__all__ = ["WeakSet"]
9
10
class _IterationGuard:
    """Mark a weak container as "being iterated" while the guard is active.

    On entry the guard adds itself to the container's ``_iterating`` set; on
    exit it removes itself again and, once no guard is left, calls the
    container's ``_commit_removals()``.  A container can therefore postpone
    removals for as long as ``_iterating`` is non-empty.
    """
    # This technique should be relatively thread-safe (since sets are).

    def __init__(self, weakcontainer):
        # Only a weak reference: the guard must neither keep the container
        # alive nor form a reference cycle with it.
        self.weakcontainer = ref(weakcontainer)

    def __enter__(self):
        container = self.weakcontainer()
        if container is None:
            # Container already collected: nothing to register with.
            return self
        container._iterating.add(self)
        return self

    def __exit__(self, e, t, b):
        container = self.weakcontainer()
        if container is None:
            return
        active = container._iterating
        active.remove(self)
        if active:
            # Another guard is still active; keep deferring removals.
            return
        container._commit_removals()
34
35
class WeakSet:
    """Set-like container holding only weak references to its elements.

    Elements are stored in ``self.data`` as weak references created with the
    ``self._remove`` callback.  When an element is garbage-collected, the
    callback discards its dead reference from ``self.data``.  If an
    iteration is in progress (``self._iterating`` is non-empty), the dead
    reference is queued in ``self._pending_removals`` instead, so that
    ``self.data`` is not resized mid-iteration.  Mutating methods flush that
    queue before they touch ``self.data``.

    Elements must support weak references: ``add`` lets the ``TypeError``
    raised by ``ref`` propagate, while ``in`` simply reports ``False`` for
    such objects.
    """

    def __init__(self, data=None):
        self.data = set()

        # The callback captures only a weak reference to this WeakSet, so the
        # stored element references do not keep the container alive.
        def _remove(item, selfref=ref(self)):
            self = selfref()
            if self is not None:
                if self._iterating:
                    # Iteration in progress: defer instead of resizing data.
                    self._pending_removals.append(item)
                else:
                    self.data.discard(item)
        self._remove = _remove
        # Dead weak references waiting to be discarded from ``data``.
        self._pending_removals = []
        # Iteration guards currently active on this set (see ``__iter__``).
        self._iterating = set()
        if data is not None:
            self.update(data)

    def _commit_removals(self):
        """Discard every dead reference queued in ``_pending_removals``."""
        # EAFP loop: pop until IndexError rather than testing ``while queue:``
        # first.  If the queue is emptied between such a test and the pop
        # (e.g. by another thread), no IndexError leaks to the caller.
        pop = self._pending_removals.pop
        discard = self.data.discard
        while True:
            try:
                item = pop()
            except IndexError:
                return
            discard(item)

    def __iter__(self):
        """Yield the live elements of the set.

        The ``_IterationGuard`` registers this iteration in ``_iterating``,
        so elements that die meanwhile are queued rather than removed.
        """
        with _IterationGuard(self):
            for itemref in self.data:
                item = itemref()
                if item is not None:
                    # Caveat: the iterator will keep a strong reference to
                    # `item` until it is resumed or closed.
                    yield item

    def __len__(self):
        """Return the number of elements, not counting queued dead refs."""
        return len(self.data) - len(self._pending_removals)

    def __contains__(self, item):
        """Return True if *item* is a live member of the set."""
        try:
            wr = ref(item)
        except TypeError:
            # Objects that cannot be weakly referenced are never members.
            return False
        return wr in self.data

    def __reduce__(self):
        """Support copy/pickle by rebuilding from the live elements.

        Only attributes outside the internal bookkeeping (e.g. ones added by
        subclasses or callers) are returned as state.  If the internal
        attributes were returned, restoring them would overwrite the freshly
        constructed instance's ``data``/``_remove``/... with the original's,
        and a ``copy.copy`` would then share the original's underlying set.
        """
        state = getattr(self, '__dict__', None)
        if state is not None:
            internal = ('data', '_remove', '_pending_removals', '_iterating')
            state = {name: value for name, value in state.items()
                     if name not in internal}
        return (self.__class__, (list(self),), state)

    def add(self, item):
        """Add *item*, which must support weak references."""
        if self._pending_removals:
            self._commit_removals()
        self.data.add(ref(item, self._remove))

    def clear(self):
        """Remove all elements."""
        if self._pending_removals:
            self._commit_removals()
        self.data.clear()

    def copy(self):
        """Return a new set of the same class with the same live elements."""
        return self.__class__(self)

    def pop(self):
        """Remove and return an arbitrary live element.

        Dead references popped along the way are dropped.

        Raises:
            KeyError: if no live element remains.
        """
        if self._pending_removals:
            self._commit_removals()
        while True:
            try:
                itemref = self.data.pop()
            except KeyError:
                raise KeyError('pop from empty WeakSet') from None
            item = itemref()
            if item is not None:
                return item

    def remove(self, item):
        """Remove *item*; raise KeyError if it is not a member."""
        if self._pending_removals:
            self._commit_removals()
        self.data.remove(ref(item))

    def discard(self, item):
        """Remove *item* if it is a member."""
        if self._pending_removals:
            self._commit_removals()
        self.data.discard(ref(item))

    def update(self, other):
        """Add every element of the iterable *other*."""
        if self._pending_removals:
            self._commit_removals()
        for element in other:
            self.add(element)

    def __ior__(self, other):
        self.update(other)
        return self

    def difference(self, other):
        """Return a new set with the elements of self that are not in *other*."""
        newset = self.copy()
        newset.difference_update(other)
        return newset
    __sub__ = difference

    def difference_update(self, other):
        """Remove every element of *other* from this set."""
        self.__isub__(other)

    def __isub__(self, other):
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            # x - x is empty; this also avoids iterating self while
            # discarding from the very set being iterated.
            self.data.clear()
        else:
            self.data.difference_update(ref(item) for item in other)
        return self

    def intersection(self, other):
        """Return a new set with the elements common to self and *other*."""
        return self.__class__(item for item in other if item in self)
    __and__ = intersection

    def intersection_update(self, other):
        """Keep only the elements that are also in *other*."""
        self.__iand__(other)

    def __iand__(self, other):
        if self._pending_removals:
            self._commit_removals()
        self.data.intersection_update(ref(item) for item in other)
        return self

    def issubset(self, other):
        """Return True if every element of self is in *other*."""
        return self.data.issubset(ref(item) for item in other)
    __le__ = issubset

    def __lt__(self, other):
        return self.data < set(map(ref, other))

    def issuperset(self, other):
        """Return True if every element of *other* is in self."""
        return self.data.issuperset(ref(item) for item in other)
    __ge__ = issuperset

    def __gt__(self, other):
        return self.data > set(map(ref, other))

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.data == set(map(ref, other))

    def symmetric_difference(self, other):
        """Return a new set with the elements in exactly one of the two."""
        newset = self.copy()
        newset.symmetric_difference_update(other)
        return newset
    __xor__ = symmetric_difference

    def symmetric_difference_update(self, other):
        """Update self to hold the elements in exactly one of self/*other*."""
        self.__ixor__(other)

    def __ixor__(self, other):
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            self.data.clear()
        else:
            # Elements coming from *other* may end up stored, so their refs
            # need the removal callback.
            self.data.symmetric_difference_update(ref(item, self._remove) for item in other)
        return self

    def union(self, other):
        """Return a new set with the elements of both self and *other*."""
        return self.__class__(e for s in (self, other) for e in s)
    __or__ = union

    def isdisjoint(self, other):
        """Return True if self and *other* have no element in common."""
        # Stop at the first shared element instead of materializing the
        # whole intersection as a temporary WeakSet.
        return not any(item in self for item in other)

    def __repr__(self):
        return repr(self.data)

    __class_getitem__ = classmethod(GenericAlias)
203