• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Utilities for collecting objects based on "is" comparison."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import weakref
21
22from tensorflow.python.util.compat import collections_abc
23
24
25# LINT.IfChange
class _ObjectIdentityWrapper:
  """Wraps an object so that comparing wrappers compares object identity.

  Two wrappers are equal exactly when they wrap the same object (`is`), and
  the hash is the `id()` of the wrapped object. Because equality is identity,
  an id-based hash is consistent with it; this is what allows unhashable
  objects such as trackable _ListWrapper instances to be placed in
  object-identity collections. Ordering (`<`, `>`) also follows `id()`.

  Comparing against anything that is not a wrapper raises TypeError, with one
  exception: `==` / `!=` against None report "not equal".
  """

  __slots__ = ["_wrapped", "__weakref__"]

  def __init__(self, wrapped):
    self._wrapped = wrapped

  @property
  def unwrapped(self):
    """The wrapped object."""
    return self._wrapped

  def _assert_type(self, other):
    """Raises TypeError if `other` is not an _ObjectIdentityWrapper."""
    if isinstance(other, _ObjectIdentityWrapper):
      return
    raise TypeError("Cannot compare wrapped object with unwrapped object")

  def __lt__(self, other):
    self._assert_type(other)
    theirs = other._wrapped  # pylint: disable=protected-access
    return id(self._wrapped) < id(theirs)

  def __gt__(self, other):
    self._assert_type(other)
    theirs = other._wrapped  # pylint: disable=protected-access
    return id(self._wrapped) > id(theirs)

  def __eq__(self, other):
    # None is the one non-wrapper value that may be compared; it is never
    # equal to a wrapper.
    if other is None:
      return False
    self._assert_type(other)
    theirs = other._wrapped  # pylint: disable=protected-access
    return theirs is self._wrapped

  def __ne__(self, other):
    is_equal = self.__eq__(other)
    return not is_equal

  def __hash__(self):
    # Hashing the wrapped object's id() also works when the wrapped object is
    # a weakref: _WeakObjectIdentityWrapper depends on weakref.ref(a) being
    # the very same object as weakref.ref(a), and therefore having the same
    # id().
    return id(self._wrapped)

  def __repr__(self):
    cls_name = type(self).__name__
    return "<{} wrapping {!r}>".format(cls_name, self._wrapped)
72
73
class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
  """Identity wrapper that holds its object through a `weakref.ref`.

  Equality and hashing are the base wrapper's, applied to the weak reference
  object itself. Two wrappers of the same live object therefore compare equal
  only because `weakref.ref(a) is weakref.ref(a)` for callback-free refs.
  `unwrapped` dereferences the weak reference and returns None once the
  referent is gone.
  """

  __slots__ = ()

  def __init__(self, wrapped):
    # Store the weak reference rather than the object itself, so the wrapper
    # does not keep `wrapped` alive.
    ref = weakref.ref(wrapped)
    super().__init__(ref)

  @property
  def unwrapped(self):
    referent = self._wrapped()
    return referent
84
85
class Reference(_ObjectIdentityWrapper):
  """A reference to an object, compared and hashed by identity.

  Two references are equal exactly when they refer to the same object (`is`,
  not `==`), and a reference hashes by the referenced object's `id()`. This
  lets references serve as set members or dict keys even for objects that
  are unhashable or define value-based equality.

  ```python
  x = [1]
  y = [1]

  x_ref1 = Reference(x)
  x_ref2 = Reference(x)
  y_ref2 = Reference(y)

  print(x_ref1 == x_ref2)
  ==> True

  print(x_ref1 == y_ref2)
  ==> False
  ```

  Comparing a reference with a non-reference object other than `None` (for
  example `x_ref1 == y`) raises `TypeError`. Use `deref()` to get the
  referenced object itself.
  """

  __slots__ = ()

  # Hide the inherited `unwrapped` accessor. A property with no getter makes
  # `ref.unwrapped` raise AttributeError, so callers use `deref()` instead.
  unwrapped = property()

  def deref(self):
    """Returns the referenced object.

    ```python
    x_ref = Reference(x)
    print(x is x_ref.deref())
    ==> True
    ```
    """
    return self._wrapped
120
121
class ObjectIdentityDictionary(collections_abc.MutableMapping):
  """A mutable mapping data structure which compares using "is".

  This is necessary because we have trackable objects (_ListWrapper) which
  have behavior identical to built-in Python lists (including being unhashable
  and comparing based on the equality of their contents by default).

  Keys are stored wrapped (see `_wrap_key`), so lookups hash and compare by
  object identity. Iteration yields the original, unwrapped keys.
  """

  __slots__ = ["_storage"]

  def __init__(self):
    # Maps key wrapper -> value.
    self._storage = {}

  def _wrap_key(self, key):
    """Returns the identity-comparing wrapper used as the storage key."""
    return _ObjectIdentityWrapper(key)

  def __getitem__(self, key):
    return self._storage[self._wrap_key(key)]

  def __setitem__(self, key, value):
    self._storage[self._wrap_key(key)] = value

  def __delitem__(self, key):
    del self._storage[self._wrap_key(key)]

  def __contains__(self, key):
    # Look the wrapper up directly. The MutableMapping default goes through
    # __getitem__ and a KeyError round-trip for every miss.
    return self._wrap_key(key) in self._storage

  def clear(self):
    # The MutableMapping default removes entries one at a time via popitem()
    # (each call restarting __iter__ and calling __getitem__). Emptying the
    # underlying dict is a single operation.
    self._storage.clear()

  def __len__(self):
    return len(self._storage)

  def __iter__(self):
    for key in self._storage:
      yield key.unwrapped

  def __repr__(self):
    # Use the runtime class name so subclasses report themselves correctly.
    return "%s(%s)" % (type(self).__name__, repr(self._storage))
156
157
class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
  """Like weakref.WeakKeyDictionary, but compares objects with "is".

  Keys are held through weak references. An entry whose key has been garbage
  collected is no longer yielded, and it is removed from the storage the next
  time the dictionary is iterated (which `__len__` also does).
  """

  __slots__ = ["__weakref__"]

  def _wrap_key(self, key):
    """Returns a weakly-referencing, identity-comparing wrapper for `key`."""
    return _WeakObjectIdentityWrapper(key)

  def __len__(self):
    # Count by iterating `self` so that entries whose keys have died are
    # discarded rather than counted. `list(self)` is deliberately avoided:
    # list() queries __len__ for a size hint, which would recurse here.
    return sum(1 for _ in self)

  def __iter__(self):
    # Iterate over a snapshot of the wrappers so that dead entries can be
    # removed from the storage without invalidating the iteration.
    for wrapper in list(self._storage):
      unwrapped = wrapper.unwrapped
      if unwrapped is None:
        # `wrapper` already is the storage key. `del self[wrapper]` would wrap
        # it a second time, never match, and raise KeyError.
        del self._storage[wrapper]
      else:
        yield unwrapped
178
179
class ObjectIdentitySet(collections_abc.MutableSet):
  """Like the built-in set, but compares objects with "is".

  Members are stored wrapped (see `_wrap_key`), so membership tests hash and
  compare by object identity. Iteration yields the original objects.
  """

  __slots__ = ["_storage", "__weakref__"]

  def __init__(self, *args):
    # Accepts what `list()` accepts: nothing, or a single iterable of members.
    self._storage = set(self._wrap_key(obj) for obj in list(*args))

  @classmethod
  def _from_storage(cls, storage):
    """Returns a new `cls` instance that adopts `storage` (a set of wrappers).

    Building `cls` rather than a hard-coded ObjectIdentitySet means that, when
    called on a subclass, the result's `_wrap_key` matches the kind of
    wrappers held in `storage`.
    """
    result = cls()
    result._storage = storage  # pylint: disable=protected-access
    return result

  def _wrap_key(self, key):
    """Returns the identity-comparing wrapper stored for `key`."""
    return _ObjectIdentityWrapper(key)

  def __contains__(self, key):
    return self._wrap_key(key) in self._storage

  def discard(self, key):
    self._storage.discard(self._wrap_key(key))

  def add(self, key):
    self._storage.add(self._wrap_key(key))

  def update(self, items):
    self._storage.update([self._wrap_key(item) for item in items])

  def clear(self):
    self._storage.clear()

  def intersection(self, items):
    # NOTE(review): unlike `difference`, this returns a built-in `set` of the
    # internal wrapper objects, not an ObjectIdentitySet. Left unchanged
    # because callers may depend on that return type.
    return self._storage.intersection([self._wrap_key(item) for item in items])

  def difference(self, items):
    """Returns a new set of the members not in `items` (compared by "is")."""
    # `self._from_storage` builds an instance of type(self). A subclass that
    # stores different wrappers (e.g. weak ones) therefore gets back a set
    # whose lookups can match them.
    return self._from_storage(
        self._storage.difference([self._wrap_key(item) for item in items]))

  def __len__(self):
    return len(self._storage)

  def __iter__(self):
    # Iterate over a snapshot so the set may be modified while iterating.
    keys = list(self._storage)
    for key in keys:
      yield key.unwrapped
226
227
class ObjectIdentityWeakSet(ObjectIdentitySet):
  """Like weakref.WeakSet, but compares objects with "is".

  Members are held through weak references. A member that has been garbage
  collected is no longer yielded, and its wrapper is removed from the storage
  the next time the set is iterated (which `__len__` also does).
  """

  __slots__ = ()

  def _wrap_key(self, key):
    """Returns a weakly-referencing, identity-comparing wrapper for `key`."""
    return _WeakObjectIdentityWrapper(key)

  def __len__(self):
    # Count by iterating `self` so dead members are discarded, not counted.
    return len([_ for _ in self])

  def __iter__(self):
    # Snapshot the wrappers so dead members can be removed while iterating.
    for wrapper in list(self._storage):
      unwrapped = wrapper.unwrapped
      if unwrapped is None:
        # `wrapper` already is a storage element. `self.discard(wrapper)`
        # would wrap it again, never match, and silently leave it behind.
        self._storage.discard(wrapper)
      else:
        yield unwrapped
248# LINT.ThenChange(//tensorflow/python/keras/utils/object_identity.py)
249