"""Utilities for collecting objects based on "is" comparison."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import weakref

from tensorflow.python.util.compat import collections_abc


# LINT.IfChange
class _ObjectIdentityWrapper:
  """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped.

  Since __eq__ is based on object identity, it's safe to also define __hash__
  based on object ids. This lets us add unhashable types like trackable
  _ListWrapper objects to object-identity collections.

  Comparisons are only defined between wrappers: `==` / `!=` against `None`
  report the values as unequal, and any other comparison with a non-wrapper
  value raises TypeError.
  """

  __slots__ = ["_wrapped", "__weakref__"]

  def __init__(self, wrapped):
    self._wrapped = wrapped

  @property
  def unwrapped(self):
    """Returns the wrapped object."""
    return self._wrapped

  def _assert_type(self, other):
    """Raises TypeError if `other` is not an identity wrapper."""
    if not isinstance(other, _ObjectIdentityWrapper):
      raise TypeError("Cannot compare wrapped object with unwrapped object")

  def __lt__(self, other):
    self._assert_type(other)
    return id(self._wrapped) < id(other._wrapped)  # pylint: disable=protected-access

  def __gt__(self, other):
    self._assert_type(other)
    return id(self._wrapped) > id(other._wrapped)  # pylint: disable=protected-access

  def __eq__(self, other):
    if other is None:
      return False
    self._assert_type(other)
    return self._wrapped is other._wrapped  # pylint: disable=protected-access

  def __ne__(self, other):
    return not self.__eq__(other)

  def __hash__(self):
    # Wrapper id() is also fine for weakrefs. In fact, we rely on
    # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is
    # weakref.ref(a) in _WeakObjectIdentityWrapper.
    return id(self._wrapped)

  def __repr__(self):
    return "<{} wrapping {!r}>".format(type(self).__name__, self._wrapped)


class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
  """Identity wrapper that holds only a weak reference to its object.

  `unwrapped` returns None once the referenced object has been garbage
  collected.
  """

  __slots__ = ()

  def __init__(self, wrapped):
    super(_WeakObjectIdentityWrapper, self).__init__(weakref.ref(wrapped))

  @property
  def unwrapped(self):
    """Returns the referenced object, or None if it has been collected."""
    return self._wrapped()


class Reference(_ObjectIdentityWrapper):
  """Hashable reference to an object, compared by object identity.

  Two references are equal only when they refer to the very same object,
  even if the referenced objects themselves compare equal.

  ```python
  x = [1]
  y = [1]

  x_ref1 = Reference(x)
  x_ref2 = Reference(x)
  y_ref2 = Reference(y)

  print(x_ref1 == x_ref2)
  ==> True

  print(x_ref1 == y_ref2)
  ==> False
  ```
  """

  __slots__ = ()

  # Disabling super class' unwrapped field; use `deref()` instead.
  unwrapped = property()

  def deref(self):
    """Returns the referenced object.

    ```python
    x_ref = Reference(x)
    print(x is x_ref.deref())
    ==> True
    ```
    """
    return self._wrapped


class ObjectIdentityDictionary(collections_abc.MutableMapping):
  """A mutable mapping data structure which compares using "is".

  This is necessary because we have trackable objects (_ListWrapper) which
  have behavior identical to built-in Python lists (including being unhashable
  and comparing based on the equality of their contents by default).
  """

  __slots__ = ["_storage"]

  def __init__(self):
    # Maps identity wrappers of the caller's keys to their values.
    self._storage = {}

  def _wrap_key(self, key):
    return _ObjectIdentityWrapper(key)

  def __getitem__(self, key):
    return self._storage[self._wrap_key(key)]

  def __setitem__(self, key, value):
    self._storage[self._wrap_key(key)] = value

  def __delitem__(self, key):
    del self._storage[self._wrap_key(key)]

  def __len__(self):
    return len(self._storage)

  def __iter__(self):
    for key in self._storage:
      yield key.unwrapped

  def __repr__(self):
    return "ObjectIdentityDictionary(%s)" % repr(self._storage)


class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
  """Like weakref.WeakKeyDictionary, but compares objects with "is".

  Entries whose keys have been garbage collected are purged lazily, whenever
  the dictionary is iterated (which `len()` also does).
  """

  __slots__ = ["__weakref__"]

  def _wrap_key(self, key):
    return _WeakObjectIdentityWrapper(key)

  def __len__(self):
    # Iterate (rather than measuring `_storage`) so that entries with dead
    # keys are purged instead of counted.
    return sum(1 for _ in self)

  def __iter__(self):
    # Iterate over a snapshot: dead entries are removed from `_storage` as we
    # go, and resizing a dict while iterating over it raises RuntimeError.
    for wrapped_key in list(self._storage):
      unwrapped = wrapped_key.unwrapped
      if unwrapped is None:
        # `wrapped_key` is the stored wrapper itself, so remove it directly;
        # `del self[...]` would wrap it a second time and never find it.
        # pop() tolerates the entry having been deleted since the snapshot.
        self._storage.pop(wrapped_key, None)
      else:
        yield unwrapped


class ObjectIdentitySet(collections_abc.MutableSet):
  """Like the built-in set, but compares objects with "is"."""

  __slots__ = ["_storage", "__weakref__"]

  def __init__(self, *args):
    # Like set(), accepts zero or one iterable of initial elements.
    self._storage = {self._wrap_key(obj) for obj in list(*args)}

  @classmethod
  def _from_storage(cls, storage):
    """Builds an instance of `cls` whose storage is the given wrapper set.

    Instantiating `cls` (rather than always ObjectIdentitySet) keeps the
    result's `_wrap_key` consistent with the wrappers already in `storage`,
    so membership tests on the result keep working for subclasses.
    """
    result = cls()
    result._storage = storage  # pylint: disable=protected-access
    return result

  def _wrap_key(self, key):
    return _ObjectIdentityWrapper(key)

  def __contains__(self, key):
    return self._wrap_key(key) in self._storage

  def discard(self, key):
    self._storage.discard(self._wrap_key(key))

  def add(self, key):
    self._storage.add(self._wrap_key(key))

  def update(self, items):
    self._storage.update([self._wrap_key(item) for item in items])

  def clear(self):
    self._storage.clear()

  def intersection(self, items):
    """Returns the wrappers of elements present both here and in `items`.

    NOTE(review): unlike `difference`, this returns a plain `set` of wrapper
    objects rather than an ObjectIdentitySet; kept as-is for compatibility
    with existing callers.
    """
    return self._storage.intersection([self._wrap_key(item) for item in items])

  def difference(self, items):
    """Returns a new set of the same type with the elements of `items` removed."""
    return self._from_storage(
        self._storage.difference([self._wrap_key(item) for item in items]))

  def __len__(self):
    return len(self._storage)

  def __iter__(self):
    # Snapshot the wrappers so the set may be modified during iteration.
    keys = list(self._storage)
    for key in keys:
      yield key.unwrapped


class ObjectIdentityWeakSet(ObjectIdentitySet):
  """Like weakref.WeakSet, but compares objects with "is".

  Elements that have been garbage collected are purged lazily, whenever the
  set is iterated (which `len()` also does).
  """

  __slots__ = ()

  def _wrap_key(self, key):
    return _WeakObjectIdentityWrapper(key)

  def __len__(self):
    # Iterate so that dead elements are purged instead of counted.
    return sum(1 for _ in self)

  def __iter__(self):
    # Iterate over a snapshot because dead wrappers are removed as we go.
    for wrapped_key in list(self._storage):
      unwrapped = wrapped_key.unwrapped
      if unwrapped is None:
        # Remove the stored wrapper itself; `self.discard()` would wrap it a
        # second time and silently miss the dead entry.
        self._storage.discard(wrapped_key)
      else:
        yield unwrapped
# LINT.ThenChange(//tensorflow/python/keras/utils/object_identity.py)