• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Live entity inspection utilities.
16
17This module contains whatever inspect doesn't offer out of the box.
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import itertools
25import types
26
27import six
28
29from tensorflow.python.util import tf_inspect
30
31
# These functions test negative for isinstance(*, types.BuiltinFunctionType)
# and inspect.isbuiltin, and are generally not visible in globals().
# TODO(mdan): Remove this.
# Each entry is keyed by the builtin's own __name__ and maps to the builtin
# object itself.
SPECIAL_BUILTINS = {
    builtin.__name__: builtin
    for builtin in (dict, enumerate, float, int, len, list, print, range,
                    tuple, type, zip)
}

if six.PY2:
  # `xrange` only exists on Python 2; this line never executes on Python 3.
  SPECIAL_BUILTINS['xrange'] = xrange
51
52
def islambda(f):
  """Returns True if `f` is a function object created by a lambda expression.

  Lambdas are recognized by their `__name__`, which is '<lambda>'. Anything
  that tf_inspect does not consider a function, or that has no `__name__`,
  yields False.
  """
  is_named_function = tf_inspect.isfunction(f) and hasattr(f, '__name__')
  if is_named_function:
    return f.__name__ == '<lambda>'
  return False
59
60
def isnamedtuple(f):
  """Returns True if the argument is a namedtuple-like class.

  A namedtuple-like class is a subclass of `tuple` with a `_fields` attribute
  that is itself a tuple whose entries are all `str`.
  """
  if not tf_inspect.isclass(f) or not issubclass(f, tuple):
    return False
  if not hasattr(f, '_fields'):
    return False
  field_names = getattr(f, '_fields')
  # Use a distinct loop variable so the parameter `f` is not shadowed.
  return (isinstance(field_names, tuple) and
          all(isinstance(field, str) for field in field_names))
73
74
def isbuiltin(f):
  """Returns True if the argument is a built-in function.

  `f` counts as built-in when it is identical to one of the values exposed by
  the builtins module, is a `types.BuiltinFunctionType`, or is reported as
  built-in by `tf_inspect.isbuiltin`.

  Args:
    f: Any object.

  Returns:
    bool.
  """
  # Compare by identity, not equality: `f in values()` would call `__eq__`.
  # That made e.g. `isbuiltin(0)` / `isbuiltin(1)` True (False == 0,
  # True == 1), and it can raise for objects that overload comparison, such
  # as arrays or tensors.
  if any(f is builtin for builtin in six.moves.builtins.__dict__.values()):
    return True
  if isinstance(f, types.BuiltinFunctionType):
    return True
  if tf_inspect.isbuiltin(f):
    return True
  return False
84
85
def getnamespace(f):
  """Returns the complete namespace of a function.

  Namespace is defined here as the mapping of all non-local variables to values.
  This includes the globals and the closure variables. Note that this captures
  the entire globals collection of the function, and may contain extra symbols
  that it does not actually use.

  Closure variables whose cells are still empty have no value and are left out
  of the result. An example is a variable of the enclosing scope that has not
  been assigned yet when `f` is inspected.

  Args:
    f: User defined function.
  Returns:
    A dict mapping symbol names to values. Closure variables override globals
    of the same name.
  """
  namespace = dict(six.get_function_globals(f))
  closure = six.get_function_closure(f)
  freevars = six.get_function_code(f).co_freevars
  if freevars and closure:
    for name, cell in zip(freevars, closure):
      try:
        namespace[name] = cell.cell_contents
      except ValueError:
        # Reading an empty cell raises ValueError. The variable is undefined
        # at this point, so it is omitted instead of failing the whole lookup.
        pass
  return namespace
106
107
def getqualifiedname(namespace, object_, max_depth=5, visited=None):
  """Returns the name by which a value can be referred to in a given namespace.

  If the object defines a parent module, the function attempts to use it to
  locate the object.

  This function will recurse inside modules, but it will not search objects for
  attributes. The recursion depth is controlled by max_depth.

  Args:
    namespace: Dict[str, Any], the namespace to search into.
    object_: Any, the value to search.
    max_depth: Optional[int], a limit to the recursion depth when searching
        inside modules.
    visited: Optional[Set[int]], ID of modules to avoid visiting.
  Returns: Union[str, None], the fully-qualified name that resolves to the value
      object_, or None if it couldn't be found.
  """
  if visited is None:
    visited = set()

  # Remember the caller's dict. The parent-module guard below must compare
  # against it by identity, and the copy made next would defeat that.
  original_namespace = namespace

  # Copy the dict to avoid "changed size error" during concurrent invocations.
  # TODO(mdan): This is on the hot path. Can we avoid the copy?
  namespace = dict(namespace)

  for name, value in namespace.items():
    # The value may be referenced by more than one symbol, case in which
    # any symbol will be fine. If the program contains symbol aliases that
    # change over time, this may capture a symbol that will later point to
    # something else.
    # TODO(mdan): Prefer the symbol that matches the value type name.
    if object_ is value:
      return name

  # If an object is not found, try to search its parent modules.
  parent = tf_inspect.getmodule(object_)
  # Skip the parent when we are already searching inside its dict. Without
  # this, a module that references itself makes the nested lookup below
  # recurse forever.
  if (parent is not None and parent is not object_ and
      getattr(parent, '__dict__', None) is not original_namespace):
    # Both nested lookups use max_depth=0: they only scan the top level of the
    # respective namespace and never descend into modules.
    parent_name = getqualifiedname(
        namespace, parent, max_depth=0, visited=visited)
    if parent_name is not None:
      name_in_parent = getqualifiedname(
          parent.__dict__, object_, max_depth=0, visited=visited)
      # The object may legitimately be absent from its module's top level
      # (e.g. methods or nested functions). In that case fall through to the
      # module search instead of asserting; an assert would also be stripped
      # under -O and yield '<parent>.None'.
      if name_in_parent is not None:
        return '{}.{}'.format(parent_name, name_in_parent)

  if max_depth:
    # `namespace` is a private copy, so iterating it is safe even if modules
    # are loaded (mutating the original dict) during the search.
    for name, value in namespace.items():
      if tf_inspect.ismodule(value) and id(value) not in visited:
        visited.add(id(value))
        name_in_module = getqualifiedname(value.__dict__, object_,
                                          max_depth - 1, visited)
        if name_in_module is not None:
          return '{}.{}'.format(name, name_in_module)
  return None
169
170
def _get_unbound_function(m):
  """Returns `m.im_func` when `m` has one (Python 2 methods), else `m`."""
  # TODO(mdan): Figure out why six.get_unbound_function fails in some cases.
  # The failure case is for tf.keras.Model.
  return m.im_func if hasattr(m, 'im_func') else m
177
178
def getdefiningclass(m, owner_class):
  """Resolves the class (e.g. one of the superclasses) that defined a method.

  The full MRO of `owner_class` is scanned, from the most generic class towards
  `owner_class` itself. The first class whose attribute of the same name
  resolves to `m` is returned. Scanning only the direct bases would credit a
  method inherited from an indirect ancestor to the intermediate class that
  merely inherits it. That in turn yields the wrong starting point for
  `super()` resolution.

  Args:
    m: A function or method, bound or unbound.
    owner_class: The class from which `m` was retrieved.

  Returns:
    The class in the MRO of `owner_class` that defines `m`, or `owner_class`
    if no such class could be identified.
  """
  # Normalize bound functions to their respective unbound versions.
  m = _get_unbound_function(m)
  for superclass in reversed(tf_inspect.getmro(owner_class)):
    if hasattr(superclass, m.__name__):
      superclass_m = getattr(superclass, m.__name__)
      if _get_unbound_function(superclass_m) is m:
        return superclass
      elif hasattr(m, '__self__') and m.__self__ == owner_class:
        # Python 3 class methods only work this way it seems :S
        return superclass
  return owner_class
192
193
def istfmethodtarget(m):
  """Tests whether an object is a `function.TfMethodTarget`.

  The check is structural: `m` must have a `__self__`, and that object must
  expose both `weakrefself_target__` and `weakrefself_func__`.
  """
  # See eager.function.TfMethodTarget for more details.
  if not hasattr(m, '__self__'):
    return False
  required_attrs = ('weakrefself_target__', 'weakrefself_func__')
  return all(hasattr(m.__self__, attr) for attr in required_attrs)
200
201
def getmethodself(m):
  """An extended version of inspect.getmethodclass.

  Returns the object `m` is bound to, or None when `m` has no `__self__` or it
  is None. For a `TfMethodTarget`-style binding (see istfmethodtarget), returns
  the `target` of `__self__` instead.
  """
  bound_to = m.__self__ if hasattr(m, '__self__') else None
  if bound_to is None:
    return None

  # A fallback allowing methods to be actually bound to a type different
  # than __self__. This is useful when a strong reference from the method
  # to the object is not desired, for example when caching is involved.
  return m.__self__.target if istfmethodtarget(m) else bound_to
216
217
def getmethodclass(m):
  """Resolves a function's owner, e.g. a method's class.

  Note that this returns the object that the function was retrieved from, not
  necessarily the class where it was defined.

  This function relies on Python stack frame support in the interpreter, and
  has the same limitations that inspect.currentframe.

  Limitations. This function will only work correctly if the owned class is
  visible in the caller's global or local variables.

  Resolution order:
    1. Callable objects without a `__name__`: their own class, if that class
       is an instance of `six.class_types`.
    2. Objects for which getmethodself returns a value: that value if it is a
       class, otherwise its class.
    3. Otherwise, every value in the caller's locals and then globals whose
       attribute named `m.__name__` resolves (after `im_func` unwrapping) to
       `m` itself is collected as a candidate owner.

  Args:
    m: A user defined function

  Returns:
    The class that this function was retrieved from, or None if the function
    is not an object or class method, or the class that owns the object or
    method is not visible to m. Note that in step 3 a single candidate is
    returned as-is, so it may be a non-class object (e.g. a module or an
    instance) that exposes `m` as an attribute.

  Raises:
    ValueError: if the class could not be resolved for any unexpected reason.
  """

  # Callable objects: return their own class.
  if (not hasattr(m, '__name__') and hasattr(m, '__class__') and
      hasattr(m, '__call__')):
    if isinstance(m.__class__, six.class_types):
      return m.__class__

  # Instance method and class methods: return the class of "self".
  m_self = getmethodself(m)
  if m_self is not None:
    if tf_inspect.isclass(m_self):
      return m_self
    return m_self.__class__

  # Class, static and unbound methods: search all defined classes in any
  # namespace. This is inefficient but more robust method.
  owners = []
  # Per the docstring, this is meant to be the frame of getmethodclass's
  # caller, whose locals and globals are searched below.
  caller_frame = tf_inspect.currentframe().f_back
  try:
    # TODO(mdan): This doesn't consider cell variables.
    # TODO(mdan): This won't work if the owner is hidden inside a container.
    # Cell variables may be pulled using co_freevars and the closure.
    for v in itertools.chain(caller_frame.f_locals.values(),
                             caller_frame.f_globals.values()):
      if hasattr(v, m.__name__):
        candidate = getattr(v, m.__name__)
        # Py2 methods may be bound or unbound, extract im_func to get the
        # underlying function.
        if hasattr(candidate, 'im_func'):
          candidate = candidate.im_func
        # NOTE(review): this rebinds `m` for the rest of the function, so the
        # ValueError message below may show the unwrapped function.
        if hasattr(m, 'im_func'):
          m = m.im_func
        if candidate is m:
          owners.append(v)
  finally:
    # Release the frame reference explicitly; presumably to avoid keeping the
    # frame (and its locals) alive through a reference cycle.
    del caller_frame

  if owners:
    if len(owners) == 1:
      return owners[0]

    # If multiple owners are found, and are not subclasses, raise an error.
    # NOTE(review): every element of owner_types is a class and is itself a
    # member of owner_types, so issubclass(o, owner_types) is always True.
    # The first owner type is therefore always returned and the ValueError
    # below appears unreachable. The intended check was presumably "o is a
    # subclass of every other owner type"; confirm before changing, as
    # callers may rely on the current behavior.
    owner_types = tuple(o if tf_inspect.isclass(o) else type(o) for o in owners)
    for o in owner_types:
      if tf_inspect.isclass(o) and issubclass(o, tuple(owner_types)):
        return o
    raise ValueError('Found too many owners of %s: %s' % (m, owners))

  return None
290
291
def getfutureimports(entity):
  """Detects what future imports are necessary to safely execute entity source.

  The detection looks at the globals of `entity`. Every global whose value has
  `__module__ == '__future__'` (i.e. a `__future__` feature imported into the
  defining module) is reported by name.

  Args:
    entity: Any object

  Returns:
    A tuple of future strings, sorted alphabetically. The tuple is empty when
    `entity` is not a function.
  """
  if not tf_inspect.isfunction(entity):
    return tuple()
  future_names = [
      name for name, value in entity.__globals__.items()
      if getattr(value, '__module__', None) == '__future__'
  ]
  future_names.sort()
  return tuple(future_names)
305
306
class SuperWrapperForDynamicAttrs(object):
  """Wraps a `super` object so that instance attributes resolve through it.

  Attribute lookup on a `super` object does not see attributes assigned on the
  instance. Methods reached through `super` can still see them on `self`:

    class Foo(object):
      def __init__(self):
        self.a = lambda: 1

      def bar(self):
        return hasattr(self, 'a')

    class Bar(Foo):
      def bar(self):
        return super(Bar, self).bar()


    b = Bar()
    print(hasattr(super(Bar, b), 'a'))  # False
    print(super(Bar, b).bar())          # True

  This commonly shows up in Keras model hierarchies that keep references to
  layers on the instance:

    class MiniModel(keras.Model):

      def __init__(self):
        super(MiniModel, self).__init__()
        self.fc = keras.layers.Dense(1)

      def call(self, inputs, training=True):
        return self.fc(inputs)

    class DefunnedMiniModel(MiniModel):

      def call(self, inputs, training=True):
        return super(DefunnedMiniModel, self).call(inputs, training=training)

  Each lookup is tried on the wrapped `super` object first. If that object does
  not have the attribute, the lookup goes to the instance it is bound to
  (`__self__`). As a side effect, all instance attributes become visible
  through the wrapper, including those created in the subclass.
  """

  # TODO(mdan): Investigate why that happens - it may be for a reason.
  # TODO(mdan): Probably need more overrides to make it look like super.

  def __init__(self, target):
    self._target = target

  def __getattribute__(self, name):
    # Fetch the wrapped object via object.__getattribute__ so that this
    # override is not re-entered for the wrapper's own state.
    wrapped_super = object.__getattribute__(self, '_target')
    if not hasattr(wrapped_super, name):
      return getattr(wrapped_super.__self__, name)
    return getattr(wrapped_super, name)
355
356  def __getattribute__(self, name):
357    target = object.__getattribute__(self, '_target')
358    if hasattr(target, name):
359      return getattr(target, name)
360    return getattr(target.__self__, name)
361