• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Access WeakSet through the weakref module.
2# This code is separated-out because it is needed
3# by abc.py to load everything else at startup.
4
5from _weakref import ref
6
7__all__ = ['WeakSet']
8
9
10class _IterationGuard(object):
11    # This context manager registers itself in the current iterators of the
12    # weak container, such as to delay all removals until the context manager
13    # exits.
14    # This technique should be relatively thread-safe (since sets are).
15
16    def __init__(self, weakcontainer):
17        # Don't create cycles
18        self.weakcontainer = ref(weakcontainer)
19
20    def __enter__(self):
21        w = self.weakcontainer()
22        if w is not None:
23            w._iterating.add(self)
24        return self
25
26    def __exit__(self, e, t, b):
27        w = self.weakcontainer()
28        if w is not None:
29            s = w._iterating
30            s.remove(self)
31            if not s:
32                w._commit_removals()
33
34
class WeakSet(object):
    """Set-like container that holds only weak references to its items.

    Items are stored in ``self.data`` as weak references created with the
    ``self._remove`` callback, so an item that is garbage-collected drops
    out of the set on its own.  While iterators are active (tracked in
    ``self._iterating`` via ``_IterationGuard``), such removals are queued
    in ``self._pending_removals`` instead of mutating ``data`` mid-iteration;
    the mutating methods below flush that queue before doing their work.

    Items must be weak-referenceable: methods that build a weak reference
    to an item raise TypeError otherwise (``__contains__`` returns False).
    """

    def __init__(self, data=None):
        """Create the set, optionally filled from the iterable *data*."""
        self.data = set()

        def _remove(item, selfref=ref(self)):
            # Weakref callback; *item* is the dead weak reference.  The set
            # is reached through a weak reference bound as a default
            # argument, so this function does not keep the set alive.
            self = selfref()
            if self is not None:
                if self._iterating:
                    # An iteration is in progress: defer the removal.
                    self._pending_removals.append(item)
                else:
                    self.data.discard(item)
        self._remove = _remove
        # Dead weak references whose removal from ``data`` was deferred.
        self._pending_removals = []
        # Active iteration guards; non-empty while any iterator is running.
        self._iterating = set()
        if data is not None:
            self.update(data)

    def _commit_removals(self):
        """Discard every deferred dead reference from ``data``."""
        pending = self._pending_removals
        discard = self.data.discard
        while pending:
            discard(pending.pop())

    def __iter__(self):
        """Yield the live items; references whose target died are skipped."""
        with _IterationGuard(self):
            for itemref in self.data:
                item = itemref()
                if item is not None:
                    # Caveat: the iterator will keep a strong reference to
                    # `item` until it is resumed or closed.
                    yield item

    def __len__(self):
        # Queued dead references are still in ``data`` but no longer count.
        return len(self.data) - len(self._pending_removals)

    def __contains__(self, item):
        try:
            wr = ref(item)
        except TypeError:
            # Objects that cannot be weakly referenced are never members.
            return False
        return wr in self.data

    def __reduce__(self):
        """Support copy/pickle by rebuilding from a list of live items.

        Only attributes added by callers or subclasses are passed on as
        state.  The bookkeeping created by ``__init__`` (the weakref set,
        the callback bound to *this* instance, the removal queues) is
        rebuilt by the constructor call.  Passing it along would make a
        copy share ``data`` with the original, and it cannot be pickled.
        """
        state = getattr(self, '__dict__', None)
        if state is not None:
            state = state.copy()
            for name in ('data', '_remove', '_pending_removals', '_iterating'):
                state.pop(name, None)
            if not state:
                state = None
        return (self.__class__, (list(self),), state)

    # Mutable container, therefore unhashable.
    __hash__ = None

    def add(self, item):
        """Add *item*, held through a weak reference with removal callback."""
        if self._pending_removals:
            self._commit_removals()
        self.data.add(ref(item, self._remove))

    def clear(self):
        """Remove every item."""
        if self._pending_removals:
            self._commit_removals()
        self.data.clear()

    def copy(self):
        """Return a new set of the same class with the same live items."""
        return self.__class__(self)

    def pop(self):
        """Remove and return an arbitrary live item.

        Raises KeyError when no live item remains.
        """
        if self._pending_removals:
            self._commit_removals()
        while True:
            try:
                itemref = self.data.pop()
            except KeyError:
                raise KeyError('pop from empty WeakSet')
            item = itemref()
            # A reference whose target already died is dropped and skipped.
            if item is not None:
                return item

    def remove(self, item):
        """Remove *item*; raises KeyError if it is not a member."""
        if self._pending_removals:
            self._commit_removals()
        self.data.remove(ref(item))

    def discard(self, item):
        """Remove *item* if it is a member."""
        if self._pending_removals:
            self._commit_removals()
        self.data.discard(ref(item))

    def update(self, other):
        """Add every item of the iterable *other*."""
        if self._pending_removals:
            self._commit_removals()
        for element in other:
            self.add(element)

    def __ior__(self, other):
        self.update(other)
        return self

    def difference(self, other):
        """Return a new set with the items not in the iterable *other*."""
        newset = self.copy()
        newset.difference_update(other)
        return newset
    __sub__ = difference

    def difference_update(self, other):
        """Remove every item that appears in the iterable *other*."""
        self.__isub__(other)

    def __isub__(self, other):
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            self.data.clear()
        else:
            self.data.difference_update(ref(item) for item in other)
        return self

    def intersection(self, other):
        """Return a new set with the items of *other* that are in self."""
        return self.__class__(item for item in other if item in self)
    __and__ = intersection

    def intersection_update(self, other):
        """Keep only the items that also appear in the iterable *other*."""
        self.__iand__(other)

    def __iand__(self, other):
        if self._pending_removals:
            self._commit_removals()
        # Materialise the other side before touching ``data`` (it may even
        # be ``self``).
        other_refs = set(ref(item) for item in other)
        # Do not use set.intersection_update() here: with an iterable
        # argument it keeps the *argument's* keys, which would swap our
        # callback-carrying references for plain ones, so items dying later
        # would never be purged from ``data``.  Removing our own
        # non-matching entries keeps the original references in place.
        self.data.difference_update(self.data - other_refs)
        return self

    def issubset(self, other):
        """Return True if every item is also in the iterable *other*."""
        return self.data.issubset(ref(item) for item in other)
    __le__ = issubset

    def __lt__(self, other):
        return self.data < set(ref(item) for item in other)

    def issuperset(self, other):
        """Return True if every item of the iterable *other* is in self."""
        return self.data.issuperset(ref(item) for item in other)
    __ge__ = issuperset

    def __gt__(self, other):
        return self.data > set(ref(item) for item in other)

    def __eq__(self, other):
        # Only instances of this set's class (or subclasses) compare equal;
        # anything else defers to the other operand.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.data == set(ref(item) for item in other)

    def __ne__(self, other):
        opposite = self.__eq__(other)
        if opposite is NotImplemented:
            return NotImplemented
        return not opposite

    def symmetric_difference(self, other):
        """Return a new set with the items in exactly one of self/*other*."""
        newset = self.copy()
        newset.symmetric_difference_update(other)
        return newset
    __xor__ = symmetric_difference

    def symmetric_difference_update(self, other):
        """Keep the items that are in exactly one of self and *other*."""
        self.__ixor__(other)

    def __ixor__(self, other):
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            self.data.clear()
        else:
            # New entries come from *other*, so they are created with the
            # removal callback.
            self.data.symmetric_difference_update(
                ref(item, self._remove) for item in other)
        return self

    def union(self, other):
        """Return a new set with the items of both self and *other*."""
        return self.__class__(e for s in (self, other) for e in s)
    __or__ = union

    def isdisjoint(self, other):
        """Return True if no item of the iterable *other* is in self."""
        # Stop at the first shared item instead of building a temporary
        # WeakSet of the whole intersection (whose weakly-held items could
        # also vanish before being counted).
        for item in other:
            if item in self:
                return False
        return True
205