1"""Utilities for collecting objects based on "is" comparison.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import weakref 21 22from tensorflow.python.util.compat import collections_abc 23 24 25# LINT.IfChange 26class _ObjectIdentityWrapper: 27 """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped. 28 29 Since __eq__ is based on object identity, it's safe to also define __hash__ 30 based on object ids. This lets us add unhashable types like trackable 31 _ListWrapper objects to object-identity collections. 32 """ 33 34 __slots__ = ["_wrapped", "__weakref__"] 35 36 def __init__(self, wrapped): 37 self._wrapped = wrapped 38 39 @property 40 def unwrapped(self): 41 return self._wrapped 42 43 def _assert_type(self, other): 44 if not isinstance(other, _ObjectIdentityWrapper): 45 raise TypeError("Cannot compare wrapped object with unwrapped object") 46 47 def __lt__(self, other): 48 self._assert_type(other) 49 return id(self._wrapped) < id(other._wrapped) # pylint: disable=protected-access 50 51 def __gt__(self, other): 52 self._assert_type(other) 53 return id(self._wrapped) > id(other._wrapped) # pylint: disable=protected-access 54 55 def __eq__(self, other): 56 if other is None: 57 return False 58 self._assert_type(other) 59 return self._wrapped is other._wrapped # pylint: disable=protected-access 60 61 def __ne__(self, other): 62 return not self.__eq__(other) 63 64 def __hash__(self): 65 # Wrapper id() is also fine for weakrefs. In fact, we rely on 66 # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is 67 # weakref.ref(a) in _WeakObjectIdentityWrapper. 68 return id(self._wrapped) 69 70 def __repr__(self): 71 return "<{} wrapping {!r}>".format(type(self).__name__, self._wrapped) 72 73 74class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper): 75 76 __slots__ = () 77 78 def __init__(self, wrapped): 79 super(_WeakObjectIdentityWrapper, self).__init__(weakref.ref(wrapped)) 80 81 @property 82 def unwrapped(self): 83 return self._wrapped() 84 85 86class Reference(_ObjectIdentityWrapper): 87 """Reference that refers an object. 88 89 ```python 90 x = [1] 91 y = [1] 92 93 x_ref1 = Reference(x) 94 x_ref2 = Reference(x) 95 y_ref2 = Reference(y) 96 97 print(x_ref1 == x_ref2) 98 ==> True 99 100 print(x_ref1 == y) 101 ==> False 102 ``` 103 """ 104 105 __slots__ = () 106 107 # Disabling super class' unwrapped field. 108 unwrapped = property() 109 110 def deref(self): 111 """Returns the referenced object. 112 113 ```python 114 x_ref = Reference(x) 115 print(x is x_ref.deref()) 116 ==> True 117 ``` 118 """ 119 return self._wrapped 120 121 122class ObjectIdentityDictionary(collections_abc.MutableMapping): 123 """A mutable mapping data structure which compares using "is". 124 125 This is necessary because we have trackable objects (_ListWrapper) which 126 have behavior identical to built-in Python lists (including being unhashable 127 and comparing based on the equality of their contents by default). 128 """ 129 130 __slots__ = ["_storage"] 131 132 def __init__(self): 133 self._storage = {} 134 135 def _wrap_key(self, key): 136 return _ObjectIdentityWrapper(key) 137 138 def __getitem__(self, key): 139 return self._storage[self._wrap_key(key)] 140 141 def __setitem__(self, key, value): 142 self._storage[self._wrap_key(key)] = value 143 144 def __delitem__(self, key): 145 del self._storage[self._wrap_key(key)] 146 147 def __len__(self): 148 return len(self._storage) 149 150 def __iter__(self): 151 for key in self._storage: 152 yield key.unwrapped 153 154 def __repr__(self): 155 return "ObjectIdentityDictionary(%s)" % repr(self._storage) 156 157 158class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary): 159 """Like weakref.WeakKeyDictionary, but compares objects with "is".""" 160 161 __slots__ = ["__weakref__"] 162 163 def _wrap_key(self, key): 164 return _WeakObjectIdentityWrapper(key) 165 166 def __len__(self): 167 # Iterate, discarding old weak refs 168 return len(list(self._storage)) 169 170 def __iter__(self): 171 keys = self._storage.keys() 172 for key in keys: 173 unwrapped = key.unwrapped 174 if unwrapped is None: 175 del self[key] 176 else: 177 yield unwrapped 178 179 180class ObjectIdentitySet(collections_abc.MutableSet): 181 """Like the built-in set, but compares objects with "is".""" 182 183 __slots__ = ["_storage", "__weakref__"] 184 185 def __init__(self, *args): 186 self._storage = set(self._wrap_key(obj) for obj in list(*args)) 187 188 @staticmethod 189 def _from_storage(storage): 190 result = ObjectIdentitySet() 191 result._storage = storage # pylint: disable=protected-access 192 return result 193 194 def _wrap_key(self, key): 195 return _ObjectIdentityWrapper(key) 196 197 def __contains__(self, key): 198 return self._wrap_key(key) in self._storage 199 200 def discard(self, key): 201 self._storage.discard(self._wrap_key(key)) 202 203 def add(self, key): 204 self._storage.add(self._wrap_key(key)) 205 206 def update(self, items): 207 self._storage.update([self._wrap_key(item) for item in items]) 208 209 def clear(self): 210 self._storage.clear() 211 212 def intersection(self, items): 213 return self._storage.intersection([self._wrap_key(item) for item in items]) 214 215 def difference(self, items): 216 return ObjectIdentitySet._from_storage( 217 self._storage.difference([self._wrap_key(item) for item in items])) 218 219 def __len__(self): 220 return len(self._storage) 221 222 def __iter__(self): 223 keys = list(self._storage) 224 for key in keys: 225 yield key.unwrapped 226 227 228class ObjectIdentityWeakSet(ObjectIdentitySet): 229 """Like weakref.WeakSet, but compares objects with "is".""" 230 231 __slots__ = () 232 233 def _wrap_key(self, key): 234 return _WeakObjectIdentityWrapper(key) 235 236 def __len__(self): 237 # Iterate, discarding old weak refs 238 return len([_ for _ in self]) 239 240 def __iter__(self): 241 keys = list(self._storage) 242 for key in keys: 243 unwrapped = key.unwrapped 244 if unwrapped is None: 245 self.discard(key) 246 else: 247 yield unwrapped 248# LINT.ThenChange(//tensorflow/python/keras/utils/object_identity.py) 249