"""Utilities for collecting objects based on "is" comparison."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import weakref

try:
  # Python 3.3+: the container ABCs live in collections.abc. The aliases on
  # the top-level `collections` module were removed in Python 3.10.
  from collections import abc as collections_abc
except ImportError:  # Python 2: the ABCs are only available on `collections`.
  collections_abc = collections


class _ObjectIdentityWrapper(object):
  """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped.

  Since __eq__ is based on object identity, it's safe to also define __hash__
  based on object ids. This lets us add unhashable types like trackable
  _ListWrapper objects to object-identity collections.
  """

  def __init__(self, wrapped):
    self._wrapped = wrapped

  @property
  def unwrapped(self):
    """Returns the wrapped object."""
    return self._wrapped

  def __eq__(self, other):
    if isinstance(other, _ObjectIdentityWrapper):
      return self._wrapped is other._wrapped  # pylint: disable=protected-access
    return self._wrapped is other

  def __ne__(self, other):
    # Python 2 does not derive __ne__ from __eq__; keep them consistent.
    return not self.__eq__(other)

  def __hash__(self):
    # Wrapper id() is also fine for weakrefs. In fact, we rely on
    # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is
    # weakref.ref(a) in _WeakObjectIdentityWrapper.
    return id(self._wrapped)


class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
  """Identity wrapper that holds only a weak reference to the wrapped object.

  The wrapped object must support weak references (weakref.ref raises
  TypeError otherwise).
  """

  def __init__(self, wrapped):
    super(_WeakObjectIdentityWrapper, self).__init__(weakref.ref(wrapped))

  @property
  def unwrapped(self):
    """Returns the referent, or None once it has been garbage collected."""
    return self._wrapped()


class ObjectIdentityDictionary(collections_abc.MutableMapping):
  """A mutable mapping data structure which compares using "is".

  This is necessary because we have trackable objects (_ListWrapper) which
  have behavior identical to built-in Python lists (including being unhashable
  and comparing based on the equality of their contents by default).
  """

  def __init__(self):
    # Maps key wrappers (see _wrap_key) to values.
    self._storage = {}

  def _wrap_key(self, key):
    return _ObjectIdentityWrapper(key)

  def __getitem__(self, key):
    return self._storage[self._wrap_key(key)]

  def __setitem__(self, key, value):
    self._storage[self._wrap_key(key)] = value

  def __delitem__(self, key):
    del self._storage[self._wrap_key(key)]

  def __len__(self):
    return len(self._storage)

  def __iter__(self):
    for key in self._storage:
      yield key.unwrapped


class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
  """Like weakref.WeakKeyDictionary, but compares objects with "is"."""

  def _wrap_key(self, key):
    return _WeakObjectIdentityWrapper(key)

  def __len__(self):
    # Count by iterating so that entries whose keys have died are discarded
    # and not counted. Do not use list(self) here: list() asks for len() as a
    # size hint, which would recurse back into this method.
    return sum(1 for _ in self)

  def __iter__(self):
    # Iterate over a snapshot so dead entries can be removed from the
    # underlying dict without invalidating the iteration.
    for key in list(self._storage):
      unwrapped = key.unwrapped
      if unwrapped is None:
        # `key` is already the stored wrapper. Remove it directly; going
        # through __delitem__ would wrap it a second time and miss.
        self._storage.pop(key, None)
      else:
        yield unwrapped


class ObjectIdentitySet(collections_abc.MutableSet):
  """Like the built-in set, but compares objects with "is"."""

  def __init__(self, *args):
    """Creates the set, optionally from an iterable of initial elements.

    Args:
      *args: Zero or one iterable, as accepted by the built-in `list()`.
    """
    self._storage = set([self._wrap_key(obj) for obj in list(*args)])

  def _wrap_key(self, key):
    return _ObjectIdentityWrapper(key)

  def __contains__(self, key):
    return self._wrap_key(key) in self._storage

  def discard(self, key):
    self._storage.discard(self._wrap_key(key))

  def add(self, key):
    self._storage.add(self._wrap_key(key))

  def __len__(self):
    return len(self._storage)

  def __iter__(self):
    # Snapshot so callers may add/discard while iterating.
    keys = list(self._storage)
    for key in keys:
      yield key.unwrapped


class ObjectIdentityWeakSet(ObjectIdentitySet):
  """Like weakref.WeakSet, but compares objects with "is"."""

  def _wrap_key(self, key):
    return _WeakObjectIdentityWrapper(key)

  def __len__(self):
    # Iterate, discarding old weak refs (and not counting them).
    return sum(1 for _ in self)

  def __iter__(self):
    for key in list(self._storage):
      unwrapped = key.unwrapped
      if unwrapped is None:
        # `key` is already the stored wrapper. Calling self.discard(key)
        # would re-wrap it and silently remove nothing.
        self._storage.discard(key)
      else:
        yield unwrapped