• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Utilities for collecting objects based on "is" comparison."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16import weakref
17
18from tensorflow.python.util.compat import collections_abc
19
20
21# LINT.IfChange
class _ObjectIdentityWrapper:
  """Wrapper whose equality is the identity ("is") of the wrapped object.

  Two wrappers are equal exactly when they hold the very same object, so
  hashing on the `id()` of the held object is consistent with `__eq__`. This
  is what allows unhashable values -- for example trackable `_ListWrapper`
  objects, which compare by content -- to live in identity-keyed sets and
  dictionaries.

  Comparing a wrapper with anything that is not another wrapper (apart from
  `== None` / `!= None`) raises `TypeError`, so that wrapped and raw objects
  are not accidentally mixed.
  """

  __slots__ = ["_wrapped", "__weakref__"]

  def __init__(self, wrapped):
    self._wrapped = wrapped

  @property
  def unwrapped(self):
    """The object held by this wrapper."""
    return self._wrapped

  def _assert_type(self, other):
    """Raises `TypeError` unless `other` is also an identity wrapper."""
    if isinstance(other, _ObjectIdentityWrapper):
      return
    raise TypeError("Cannot compare wrapped object with unwrapped object")

  def __lt__(self, other):
    # Ordering is by id() of the held objects: arbitrary, but consistent for
    # as long as both objects are alive.
    self._assert_type(other)
    mine, theirs = id(self._wrapped), id(other._wrapped)  # pylint: disable=protected-access
    return mine < theirs

  def __gt__(self, other):
    self._assert_type(other)
    mine, theirs = id(self._wrapped), id(other._wrapped)  # pylint: disable=protected-access
    return mine > theirs

  def __eq__(self, other):
    # `wrapper == None` is answered with False instead of raising TypeError.
    if other is not None:
      self._assert_type(other)
      return self._wrapped is other._wrapped  # pylint: disable=protected-access
    return False

  def __ne__(self, other):
    return not self.__eq__(other)

  def __hash__(self):
    # Hashing the id of the held object agrees with the identity-based
    # __eq__. The same holds for weak wrappers, whose held object is a
    # weakref: _WeakObjectIdentityWrapper relies on weakref.ref(a) returning
    # the very same ref object each time (so id(weakref.ref(a)) ==
    # id(weakref.ref(a)) and weakref.ref(a) is weakref.ref(a)).
    return id(self._wrapped)

  def __repr__(self):
    template = "<{} wrapping {!r}>"
    return template.format(type(self).__name__, self._wrapped)
68
69
class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
  """Identity wrapper that holds only a weak reference to its object.

  The stored value is a `weakref.ref` to the wrapped object. `unwrapped`
  dereferences it, yielding the object or `None` once the object has been
  garbage collected.
  """

  __slots__ = ()

  def __init__(self, wrapped):
    ref = weakref.ref(wrapped)
    super().__init__(ref)

  @property
  def unwrapped(self):
    ref = self._wrapped
    return ref()
80
81
class Reference(_ObjectIdentityWrapper):
  """Reference that refers to an object and compares by identity.

  Two references are equal exactly when they refer to the same object, even
  if the referenced objects themselves would compare equal with `==`:

  ```python
  x = [1]
  y = [1]

  x_ref1 = Reference(x)
  x_ref2 = Reference(x)
  y_ref2 = Reference(y)

  print(x_ref1 == x_ref2)
  ==> True

  print(x_ref1 == y_ref2)
  ==> False
  ```

  Comparing a reference with an object that is not a reference (other than
  `None`), e.g. `x_ref1 == y`, raises `TypeError` rather than returning False.
  """

  __slots__ = ()

  # Disabling super class' unwrapped field: a property with no getter, so
  # accessing `ref.unwrapped` raises AttributeError. Use `deref()` instead.
  unwrapped = property()

  def deref(self):
    """Returns the referenced object.

    ```python
    x_ref = Reference(x)
    print(x is x_ref.deref())
    ==> True
    ```
    """
    return self._wrapped
116
117
class ObjectIdentityDictionary(collections_abc.MutableMapping):
  """A mutable mapping whose keys are matched with "is" rather than "==".

  This is needed for trackable objects such as `_ListWrapper`, which behave
  exactly like built-in Python lists: they are unhashable and, by default,
  compare equal based on their contents. Every key is stored wrapped (see
  `_wrap_key`) so that lookups go by object identity.
  """

  __slots__ = ["_storage"]

  def __init__(self):
    # Maps wrapped key -> value.
    self._storage = {}

  def _wrap_key(self, key):
    """Returns the storage-level key for `key`; subclasses may override."""
    return _ObjectIdentityWrapper(key)

  def __getitem__(self, key):
    wrapped = self._wrap_key(key)
    return self._storage[wrapped]

  def __setitem__(self, key, value):
    wrapped = self._wrap_key(key)
    self._storage[wrapped] = value

  def __delitem__(self, key):
    wrapped = self._wrap_key(key)
    del self._storage[wrapped]

  def __len__(self):
    return len(self._storage)

  def __iter__(self):
    # Yield the original (unwrapped) key objects, not their wrappers.
    yield from (wrapped.unwrapped for wrapped in self._storage)

  def __repr__(self):
    return "ObjectIdentityDictionary(%s)" % repr(self._storage)
152
153
class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
  """Like weakref.WeakKeyDictionary, but compares objects with "is".

  Keys are held through weak references. Entries whose key has been garbage
  collected are purged lazily, whenever the mapping is iterated (which
  `__len__` also does).
  """

  __slots__ = ["__weakref__"]

  def _wrap_key(self, key):
    return _WeakObjectIdentityWrapper(key)

  def __len__(self):
    # Count by iterating `self` rather than the raw storage, so entries whose
    # key has died are discarded and not counted. A generator is used instead
    # of `list(self)`: list() asks len() for a size hint and would recurse
    # back into this method.
    return sum(1 for _ in self)

  def __iter__(self):
    # Snapshot the stored keys. Dead entries are removed below, and mutating
    # a dict while iterating over it raises RuntimeError.
    for key in list(self._storage):
      unwrapped = key.unwrapped
      if unwrapped is None:
        # `key` is already a wrapper, so remove it from the storage directly.
        # `del self[key]` would wrap it a second time, never match, and raise
        # KeyError. pop() with a default tolerates the entry having been
        # purged already (e.g. by a nested iteration).
        self._storage.pop(key, None)
      else:
        yield unwrapped
174
175
class ObjectIdentitySet(collections_abc.MutableSet):
  """Like the built-in set, but compares objects with "is"."""

  __slots__ = ["_storage", "__weakref__"]

  def __init__(self, *args):
    # Accepts at most one iterable (list(*args) rejects more); every element
    # is stored wrapped so membership is decided by identity.
    self._storage = set(self._wrap_key(obj) for obj in list(*args))

  @classmethod
  def _from_storage(cls, storage):
    """Builds an instance of `cls` around an existing set of wrapped keys.

    Using `cls` instead of hard-coding ObjectIdentitySet keeps the result's
    `_wrap_key` consistent with the wrappers already in `storage`. For example,
    the weak wrappers of a weak set stay in a weak set, so membership tests on
    the result keep working.
    """
    result = cls()
    result._storage = storage  # pylint: disable=protected-access
    return result

  def _wrap_key(self, key):
    return _ObjectIdentityWrapper(key)

  def __contains__(self, key):
    return self._wrap_key(key) in self._storage

  def discard(self, key):
    self._storage.discard(self._wrap_key(key))

  def add(self, key):
    self._storage.add(self._wrap_key(key))

  def update(self, items):
    self._storage.update([self._wrap_key(item) for item in items])

  def clear(self):
    self._storage.clear()

  def intersection(self, items):
    """Returns a new set of the elements that are also (by identity) in `items`.

    The result is an identity set of the same class as `self`, matching
    `difference`. Iterating it yields the original objects, not wrappers.
    """
    return self._from_storage(
        self._storage.intersection([self._wrap_key(item) for item in items]))

  def difference(self, items):
    """Returns a new set of the elements that are not (by identity) in `items`."""
    return self._from_storage(
        self._storage.difference([self._wrap_key(item) for item in items]))

  def __len__(self):
    return len(self._storage)

  def __iter__(self):
    # Iterate over a snapshot so the set may be modified while iterating.
    keys = list(self._storage)
    for key in keys:
      yield key.unwrapped
222
223
class ObjectIdentityWeakSet(ObjectIdentitySet):
  """Like weakref.WeakSet, but compares objects with "is".

  Elements are held through weak references. Elements that have been garbage
  collected are purged lazily, whenever the set is iterated (which `__len__`
  also does).
  """

  __slots__ = ()

  def _wrap_key(self, key):
    return _WeakObjectIdentityWrapper(key)

  def __len__(self):
    # Iterate, discarding dead weak refs. A comprehension is used instead of
    # `list(self)`: list() asks len() for a size hint and would recurse.
    return len([_ for _ in self])

  def __iter__(self):
    # Snapshot the storage so dead entries can be removed while iterating.
    keys = list(self._storage)
    for key in keys:
      unwrapped = key.unwrapped
      if unwrapped is None:
        # `key` is already a wrapper, so discard it from the storage directly.
        # `self.discard(key)` would wrap it a second time, never match, and
        # leave the dead entry in the storage forever.
        self._storage.discard(key)
      else:
        yield unwrapped
244# LINT.ThenChange(//tensorflow/python/keras/utils/object_identity.py)
245