1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Live entity inspection utilities. 16 17This module contains whatever inspect doesn't offer out of the box. 18""" 19 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import itertools 25import types 26 27import six 28 29from tensorflow.python.util import tf_inspect 30 31 32# These functions test negative for isinstance(*, types.BuiltinFunctionType) 33# and inspect.isbuiltin, and are generally not visible in globals(). 34# TODO(mdan): Remove this. 
SPECIAL_BUILTINS = {
    'dict': dict,
    'enumerate': enumerate,
    'float': float,
    'int': int,
    'len': len,
    'list': list,
    'print': print,
    'range': range,
    'tuple': tuple,
    'type': type,
    'zip': zip
}

if six.PY2:
  SPECIAL_BUILTINS['xrange'] = xrange


def islambda(f):
  """Returns True if the argument is a lambda function.

  Args:
    f: Any object.

  Returns:
    True if `f` is a function (per `tf_inspect.isfunction`) whose `__name__`
    is `'<lambda>'`, False otherwise.
  """
  if not tf_inspect.isfunction(f):
    return False
  if not hasattr(f, '__name__'):
    return False
  return f.__name__ == '<lambda>'


def isnamedtuple(f):
  """Returns True if the argument is a namedtuple-like.

  A namedtuple-like is a class that subclasses `tuple` and has a `_fields`
  attribute holding a tuple of strings.

  Args:
    f: Any object.

  Returns:
    bool.
  """
  if not (tf_inspect.isclass(f) and issubclass(f, tuple)):
    return False
  if not hasattr(f, '_fields'):
    return False
  fields = getattr(f, '_fields')
  if not isinstance(fields, tuple):
    return False
  # A distinct loop name avoids shadowing the argument `f`.
  return all(isinstance(field, str) for field in fields)


def isbuiltin(f):
  """Returns True if the argument is a built-in function.

  Args:
    f: Any object.

  Returns:
    True if `f` is one of the values of the builtins module namespace, an
    instance of `types.BuiltinFunctionType`, or reported as builtin by
    `tf_inspect.isbuiltin`; False otherwise.
  """
  # Compare by identity. A plain `f in values` falls back to `==`, which
  # wrongly matches e.g. `1 == True` / `0 == False` (both True and False live
  # in the builtins namespace), and raises for array-like objects whose
  # `__eq__` returns a non-bool.
  if any(f is builtin for builtin in six.moves.builtins.__dict__.values()):
    return True
  if isinstance(f, types.BuiltinFunctionType):
    return True
  if tf_inspect.isbuiltin(f):
    return True
  return False


def getnamespace(f):
  """Returns the complete namespace of a function.

  Namespace is defined here as the mapping of all non-local variables to values.
  This includes the globals and the closure variables. Note that this captures
  the entire globals collection of the function, and may contain extra symbols
  that it does not actually use.

  Closure variables take precedence over globals of the same name. Closure
  cells that are still empty (the variable has not been assigned yet in the
  enclosing scope) are omitted.

  Args:
    f: User defined function.
  Returns:
    A dict mapping symbol names to values.
  """
  namespace = dict(six.get_function_globals(f))
  closure = six.get_function_closure(f)
  freevars = six.get_function_code(f).co_freevars
  if freevars and closure:
    for name, cell in zip(freevars, closure):
      try:
        namespace[name] = cell.cell_contents
      except ValueError:
        # Reading an empty cell raises ValueError; there is no value to map.
        pass
  return namespace


def getqualifiedname(namespace, object_, max_depth=5, visited=None):
  """Returns the name by which a value can be referred to in a given namespace.

  If the object defines a parent module, the function attempts to use it to
  locate the object.

  This function will recurse inside modules, but it will not search objects for
  attributes. The recursion depth is controlled by max_depth.

  Args:
    namespace: Dict[str, Any], the namespace to search into.
    object_: Any, the value to search.
    max_depth: Optional[int], a limit to the recursion depth when searching
      inside modules.
    visited: Optional[Set[int]], ID of modules to avoid visiting. Updated in
      place as modules are searched.
  Returns: Union[str, None], the fully-qualified name that resolves to the value
    object_, or None if it couldn't be found.
  """
  if visited is None:
    visited = set()

  # Copy the dict to avoid "changed size error" during concurrent invocations.
  # TODO(mdan): This is on the hot path. Can we avoid the copy?
  namespace = dict(namespace)

  for name, value in namespace.items():
    # The value may be referenced by more than one symbol, case in which
    # any symbol will be fine. If the program contains symbol aliases that
    # change over time, this may capture a symbol that will later point to
    # something else.
    # TODO(mdan): Prefer the symbol that matches the value type name.
    if object_ is value:
      return name

  # If an object is not found, try to search its parent modules.
  parent = tf_inspect.getmodule(object_)
  if parent is not None and parent is not object_:
    # Both lookups use max_depth=0, i.e. they only check direct entries.
    # Recursion stops once getmodule returns the object itself (as it does
    # for modules), which the guard above rejects.
    parent_name = getqualifiedname(
        namespace, parent, max_depth=0, visited=visited)
    if parent_name is not None:
      name_in_parent = getqualifiedname(
          parent.__dict__, object_, max_depth=0, visited=visited)
      # The object is not necessarily bound in its owner module (for example,
      # a function defined inside another function still reports its
      # enclosing module). In that case, fall through to the module search.
      if name_in_parent is not None:
        return '{}.{}'.format(parent_name, name_in_parent)

  if max_depth:
    # `namespace` is a private copy, so iterating it is safe even if new
    # modules are loaded meanwhile.
    for name, value in namespace.items():
      if tf_inspect.ismodule(value) and id(value) not in visited:
        visited.add(id(value))
        name_in_module = getqualifiedname(value.__dict__, object_,
                                          max_depth - 1, visited)
        if name_in_module is not None:
          return '{}.{}'.format(name, name_in_module)
  return None


def _get_unbound_function(m):
  """Returns `m.im_func` if present (Python 2 methods), else `m` itself."""
  # TODO(mdan): Figure out why six.get_unbound_function fails in some cases.
  # The failure case is for tf.keras.Model.
  if hasattr(m, 'im_func'):
    return m.im_func
  return m


def getdefiningclass(m, owner_class):
  """Resolves the class (e.g. one of the superclasses) that defined a method.

  Only the direct bases of `owner_class` are inspected, not the full MRO.

  Args:
    m: A function or method. Objects with an `im_func` attribute are first
      normalized to that underlying function.
    owner_class: The class through which `m` was obtained.

  Returns:
    The first direct base of `owner_class` that has an attribute named
    `m.__name__` which is the same underlying function as `m`, or, when `m`
    is bound to `owner_class` itself (as class methods are), the first direct
    base that has an attribute of that name. Otherwise `owner_class`.
  """
  # Normalize bound functions to their respective unbound versions.
  m = _get_unbound_function(m)
  for superclass in owner_class.__bases__:
    if hasattr(superclass, m.__name__):
      superclass_m = getattr(superclass, m.__name__)
      if _get_unbound_function(superclass_m) is m:
        return superclass
      elif hasattr(m, '__self__') and m.__self__ == owner_class:
        # Python 3 class methods only work this way it seems :S
        return superclass
  return owner_class


def istfmethodtarget(m):
  """Tests whether an object is a `function.TfMethodTarget`.

  Detection is structural: `m.__self__` must carry both the
  `weakrefself_target__` and `weakrefself_func__` attributes.
  """
  # See eager.function.TfMethodTarget for more details.
  return (hasattr(m, '__self__') and
          hasattr(m.__self__, 'weakrefself_target__') and
          hasattr(m.__self__, 'weakrefself_func__'))


def getmethodself(m):
  """Returns the object a method is bound to, or None.

  Args:
    m: Any object, typically a function or method.

  Returns:
    None if `m` has no `__self__` attribute or its `__self__` is None. If `m`
    is bound to a `TfMethodTarget` (see `istfmethodtarget`), that target's
    `target` attribute. Otherwise `m.__self__`.
  """
  if not hasattr(m, '__self__'):
    return None
  if m.__self__ is None:
    return None

  # A fallback allowing methods to be actually bound to a type different
  # than __self__. This is useful when a strong reference from the method
  # to the object is not desired, for example when caching is involved.
  if istfmethodtarget(m):
    return m.__self__.target

  return m.__self__


def getmethodclass(m):
  """Resolves a function's owner, e.g. a method's class.

  Note that this returns the object that the function was retrieved from, not
  necessarily the class where it was defined.

  This function relies on Python stack frame support in the interpreter, and
  has the same limitations that inspect.currentframe.

  Limitations. This function will only work correctly if the owned class is
  visible in the caller's global or local variables.

  Args:
    m: A user defined function

  Returns:
    The class that this function was retrieved from, or None if the function
    is not an object or class method, or the class that owns the object or
    method is not visible to m.

  Raises:
    ValueError: if the class could not be resolved for any unexpected reason.
  """

  # Callable objects: return their own class.
  if (not hasattr(m, '__name__') and hasattr(m, '__class__') and
      hasattr(m, '__call__')):
    if isinstance(m.__class__, six.class_types):
      return m.__class__

  # Instance method and class methods: return the class of "self".
  m_self = getmethodself(m)
  if m_self is not None:
    if tf_inspect.isclass(m_self):
      return m_self
    return m_self.__class__

  # Py2 methods may be bound or unbound; compare underlying functions. This
  # normalization is loop-invariant, so it is done once up front.
  if hasattr(m, 'im_func'):
    m = m.im_func

  # Class, static and unbound methods: search all defined classes in any
  # namespace. This is inefficient but more robust method.
  # The frame lookup must stay in this function so that f_back is the caller.
  owners = []
  caller_frame = tf_inspect.currentframe().f_back
  try:
    # TODO(mdan): This doesn't consider cell variables.
    # TODO(mdan): This won't work if the owner is hidden inside a container.
    # Cell variables may be pulled using co_freevars and the closure.
    for v in itertools.chain(caller_frame.f_locals.values(),
                             caller_frame.f_globals.values()):
      if hasattr(v, m.__name__):
        candidate = getattr(v, m.__name__)
        if hasattr(candidate, 'im_func'):
          candidate = candidate.im_func
        if candidate is m:
          owners.append(v)
  finally:
    # Drop the frame reference promptly to avoid a reference cycle.
    del caller_frame

  if owners:
    if len(owners) == 1:
      return owners[0]

    # If multiple owners are found, and are not subclasses, raise an error.
    # NOTE(review): every `o` below is itself an element of `owner_types`, so
    # `issubclass(o, owner_types)` holds for the first entry, which is then
    # returned; the ValueError is effectively unreachable. Also, when called
    # from module scope f_locals is typically the same dict as f_globals, so
    # each owner is collected twice. Confirm the intended disambiguation
    # before changing this logic.
    owner_types = tuple(o if tf_inspect.isclass(o) else type(o) for o in owners)
    for o in owner_types:
      if tf_inspect.isclass(o) and issubclass(o, tuple(owner_types)):
        return o
    raise ValueError('Found too many owners of %s: %s' % (m, owners))

  return None


def getfutureimports(entity):
  """Detects what future imports are necessary to safely execute entity source.

  Args:
    entity: Any object

  Returns:
    A sorted tuple of the names in the entity's globals whose values have
    `__module__ == '__future__'`. Empty if `entity` is not a function.
  """
  if not tf_inspect.isfunction(entity):
    return ()
  return tuple(sorted(name for name, value in entity.__globals__.items()
                      if getattr(value, '__module__', None) == '__future__'))


class SuperWrapperForDynamicAttrs(object):
  """A wrapper that supports dynamic attribute lookup on the super object.

  For example, in the following code, `super` incorrectly reports that
  `super(Bar, b)` lacks the `a` attribute:

    class Foo(object):
      def __init__(self):
        self.a = lambda: 1

      def bar(self):
        return hasattr(self, 'a')

    class Bar(Foo):
      def bar(self):
        return super(Bar, self).bar()


    b = Bar()
    print(hasattr(super(Bar, b), 'a'))  # False
    print(super(Bar, b).bar())  # True

  A practical situation when this tends to happen is Keras model hierarchies
  that hold references to certain layers, like this:

    class MiniModel(keras.Model):

      def __init__(self):
        super(MiniModel, self).__init__()
        self.fc = keras.layers.Dense(1)

      def call(self, inputs, training=True):
        return self.fc(inputs)

    class DefunnedMiniModel(MiniModel):

      def call(self, inputs, training=True):
        return super(DefunnedMiniModel, self).call(inputs, training=training)

  A side effect of this wrapper is that all attributes become visible, even
  those created in the subclass.
  """

  # TODO(mdan): Investigate why that happens - it may be for a reason.
  # TODO(mdan): Probably need more overrides to make it look like super.

  def __init__(self, target):
    # `target` is the wrapped super object; its `__self__` is used as the
    # fallback for attribute lookups.
    self._target = target

  def __getattribute__(self, name):
    # Bypass this override to read the wrapper's own state.
    target = object.__getattribute__(self, '_target')
    # A single lookup (instead of hasattr followed by getattr) avoids
    # evaluating descriptors such as properties twice.
    try:
      return getattr(target, name)
    except AttributeError:
      return getattr(target.__self__, name)