• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Live entity inspection utilities.

This module contains functionality that the standard `inspect` module does not
offer out of the box.
"""
19
20import builtins
21import inspect
22import itertools
23import linecache
24import sys
25import threading
26import types
27
28from tensorflow.python.util import tf_inspect
29
30# This lock seems to help avoid linecache concurrency errors.
31_linecache_lock = threading.Lock()
32
33
def islambda(f):
  """Returns True if the argument is a lambda function.

  Both the function's `__name__` and its code object's `co_name` are checked:
  some wrappers can rename the function, but changing the name of the code
  object is harder.

  Args:
    f: Any

  Returns:
    bool
  """
  if not tf_inspect.isfunction(f):
    return False
  # TODO(mdan): Look into checking only the code object.
  if not (hasattr(f, '__name__') and hasattr(f, '__code__')):
    return False
  return f.__name__ == '<lambda>' or f.__code__.co_name == '<lambda>'
43
44
def isnamedtuple(f):
  """Returns True if the argument is a namedtuple-like class.

  A class qualifies when it subclasses `tuple` and has a `_fields` attribute
  that is a tuple of strings (the shape produced by e.g.
  `collections.namedtuple`).

  Args:
    f: Any

  Returns:
    bool
  """
  if not (tf_inspect.isclass(f) and issubclass(f, tuple)):
    return False
  # A missing `_fields` is treated the same as a non-tuple `_fields`.
  fields = getattr(f, '_fields', None)
  if not isinstance(fields, tuple):
    return False
  # Use a distinct name so the generator does not shadow the parameter `f`.
  return all(isinstance(field, str) for field in fields)
57
58
def isbuiltin(f):
  """Returns True if the argument is a built-in function.

  An object qualifies if it is identical to any value exposed by the
  `builtins` module (e.g. `len`, `eval`, or built-in types such as `int`), or
  if it is an instance of `types.BuiltinFunctionType` (e.g. `[].append`).

  Args:
    f: Any

  Returns:
    bool
  """
  # Identity scan on every call rather than a cached set: the `builtins`
  # module can be modified at runtime. This also covers `eval`.
  if any(f is value for value in builtins.__dict__.values()):
    return True
  # This is exactly the check `inspect.isbuiltin` performs.
  return isinstance(f, types.BuiltinFunctionType)
71
72
def isconstructor(cls):
  """Returns True if the argument is an object constructor.

  In general, any object of type class is a constructor, with the exception
  of classes created using a callable metaclass, i.e. a metaclass that
  overrides `__call__`: calling such a class runs the metaclass's `__call__`,
  which need not construct an instance.
  See below for why a callable metaclass is not a trivial combination:
  https://docs.python.org/3/reference/datamodel.html#customizing-class-creation

  Args:
    cls: Any

  Returns:
    Bool
  """
  if not inspect.isclass(cls):
    return False
  metaclass = cls.__class__
  has_custom_call = (
      issubclass(metaclass, type) and
      hasattr(metaclass, '__call__') and
      metaclass.__call__ is not type.__call__)
  return not has_custom_call
91
92
93def _fix_linecache_record(obj):
94  """Fixes potential corruption of linecache in the presence of functools.wraps.
95
96  functools.wraps modifies the target object's __module__ field, which seems
97  to confuse linecache in special instances, for example when the source is
98  loaded from a .par file (see https://google.github.io/subpar/subpar.html).
99
100  This function simply triggers a call to linecache.updatecache when a mismatch
101  was detected between the object's __module__ property and the object's source
102  file.
103
104  Args:
105    obj: Any
106  """
107  if hasattr(obj, '__module__'):
108    obj_file = inspect.getfile(obj)
109    obj_module = obj.__module__
110
111    # A snapshot of the loaded modules helps avoid "dict changed size during
112    # iteration" errors.
113    loaded_modules = tuple(sys.modules.values())
114    for m in loaded_modules:
115      if hasattr(m, '__file__') and m.__file__ == obj_file:
116        if obj_module is not m:
117          linecache.updatecache(obj_file, m.__dict__)
118
119
def getimmediatesource(obj):
  """A variant of inspect.getsource that ignores the __wrapped__ property.

  inspect.getsource unwraps objects decorated with functools.wraps and returns
  the source of the innermost function. This function instead returns the
  source block of `obj` itself.

  Args:
    obj: Any object inspect.findsource accepts (module, class, method,
      function, traceback, frame or code object).

  Returns:
    str, the source code block that defines `obj`.

  Raises:
    OSError, TypeError: propagated from inspect.findsource when the source
      cannot be located.
  """
  with _linecache_lock:
    # Repair linecache entries that functools.wraps may have confused before
    # inspect reads source lines through linecache.
    _fix_linecache_record(obj)
    lines, lnum = inspect.findsource(obj)
    return ''.join(inspect.getblock(lines[lnum:]))
126
127
def getnamespace(f):
  """Returns the complete namespace of a function.

  Namespace is defined here as the mapping of all non-local variables to values.
  This includes the globals and the closure variables. Note that this captures
  the entire globals collection of the function, and may contain extra symbols
  that it does not actually use. Closure variables take precedence over
  globals of the same name.

  Args:
    f: User defined function.

  Returns:
    A new dict mapping symbol names to values; `f.__globals__` is not
    modified.
  """
  namespace = dict(f.__globals__)
  closure = f.__closure__
  freevars = f.__code__.co_freevars
  if freevars and closure:
    for name, cell in zip(freevars, closure):
      try:
        namespace[name] = cell.cell_contents
      except ValueError:
        # The cell is empty (the variable is not bound yet, or was deleted);
        # omit it from the namespace.
        pass
  return namespace
153
154
def getqualifiedname(namespace, object_, max_depth=5, visited=None):
  """Returns the name by which a value can be referred to in a given namespace.

  If the object defines a parent module, the function attempts to use it to
  locate the object.

  This function will recurse inside modules, but it will not search objects for
  attributes. The recursion depth is controlled by max_depth.

  Args:
    namespace: Dict[str, Any], the namespace to search into.
    object_: Any, the value to search.
    max_depth: Optional[int], a limit to the recursion depth when searching
      inside modules.
    visited: Optional[Set[int]], ID of modules to avoid visiting. Updated in
      place as modules are searched.
  Returns: Union[str, None], the fully-qualified name that resolves to the value
    object_, or None if it couldn't be found.
  """
  if visited is None:
    visited = set()

  # Kept to recognize when we are already searching a module's own dict.
  original_namespace = namespace
  # Copy the dict to avoid "changed size error" during concurrent invocations.
  # TODO(mdan): This is on the hot path. Can we avoid the copy?
  namespace = dict(namespace)

  for name, value in namespace.items():
    # The value may be referenced by more than one symbol, case in which
    # any symbol will be fine. If the program contains symbol aliases that
    # change over time, this may capture a symbol that will later point to
    # something else.
    # TODO(mdan): Prefer the symbol that matches the value type name.
    if object_ is value:
      return name

  # If an object is not found, try to search its parent modules.
  parent = tf_inspect.getmodule(object_)
  if (parent is not None and parent is not object_ and
      parent.__dict__ is not original_namespace):
    # max_depth=0: only direct lookups here, no descent into submodules. The
    # guard above stops re-searching a module from within its own dict, which
    # would otherwise recurse forever for modules that reference themselves.
    parent_name = getqualifiedname(
        namespace, parent, max_depth=0, visited=visited)
    if parent_name is not None:
      name_in_parent = getqualifiedname(
          parent.__dict__, object_, max_depth=0, visited=visited)
      # The owner module does not necessarily bind the object (e.g. nested
      # functions, or objects whose __module__ was rewritten by
      # functools.wraps). In that case fall through to the generic search
      # rather than failing (or producing "<parent>.None" under -O).
      if name_in_parent is not None:
        return '{}.{}'.format(parent_name, name_in_parent)

  if max_depth:
    # Iterating over the copy made above prevents "changed size due to
    # iteration" errors, e.g. if new modules load during iteration.
    for name, value in namespace.items():
      if tf_inspect.ismodule(value) and id(value) not in visited:
        visited.add(id(value))
        name_in_module = getqualifiedname(value.__dict__, object_,
                                          max_depth - 1, visited)
        if name_in_module is not None:
          return '{}.{}'.format(name, name_in_module)
  return None
215
216
def getdefiningclass(m, owner_class):
  """Resolves the class (e.g. one of the superclasses) that defined a method.

  Walks the MRO of `owner_class` and returns the first class whose `__dict__`
  or `__slots__` contains the method's name.

  Args:
    m: The method; only its `__name__` is used.
    owner_class: The class the method was retrieved from.

  Returns:
    The defining class, or `owner_class` if no class in the MRO matches.
  """
  method_name = m.__name__
  for super_class in inspect.getmro(owner_class):
    if method_name in getattr(super_class, '__dict__', {}):
      return super_class
    slots = getattr(super_class, '__slots__', ())
    # `__slots__ = 'name'` declares a single slot. Without normalizing it,
    # `in` would do a substring match against that string.
    if isinstance(slots, str):
      slots = (slots,)
    if method_name in slots:
      return super_class
  return owner_class
227
228
def getmethodclass(m):
  """Resolves a function's owner, e.g. a method's class.

  Note that this returns the object that the function was retrieved from, not
  necessarily the class where it was defined.

  This function relies on Python stack frame support in the interpreter, and
  has the same limitations as inspect.currentframe.

  Limitations. This function will only work correctly if the owner class is
  visible in the caller's global or local variables.

  Args:
    m: A user defined function

  Returns:
    The class that this function was retrieved from, or None if the function
    is not an object or class method, or the class that owns the object or
    method is not visible to m.
  """

  # Callable objects: return their own class.
  if (not hasattr(m, '__name__') and hasattr(m, '__class__') and
      hasattr(m, '__call__')):
    if isinstance(m.__class__, type):
      return m.__class__

  # Instance and class methods: return the class of "self" (or "cls").
  m_self = getattr(m, '__self__', None)
  if m_self is not None:
    if inspect.isclass(m_self):
      return m_self
    return m_self.__class__

  # Class, static and unbound methods: search all defined classes in any
  # namespace. This is inefficient but more robust a method.
  method_name = m.__name__
  # Py2 methods may be bound or unbound; compare the underlying functions.
  # (`im_func` does not exist on Python 3, where this is a no-op.)
  target = getattr(m, 'im_func', m)
  missing = object()
  owners = []
  caller_frame = tf_inspect.currentframe().f_back
  try:
    # TODO(mdan): This doesn't consider cell variables.
    # TODO(mdan): This won't work if the owner is hidden inside a container.
    # Cell variables may be pulled using co_freevars and the closure.
    for v in itertools.chain(caller_frame.f_locals.values(),
                             caller_frame.f_globals.values()):
      # A single getattr (instead of hasattr + getattr) evaluates any
      # property or __getattr__ on `v` only once.
      candidate = getattr(v, method_name, missing)
      if candidate is missing:
        continue
      candidate = getattr(candidate, 'im_func', candidate)
      if candidate is target:
        owners.append(v)
  finally:
    # Drop the frame reference promptly to avoid keeping a reference cycle.
    del caller_frame

  if not owners:
    return None
  if len(owners) == 1:
    return owners[0]

  # Multiple owners: the class of the first one is returned.
  # NOTE(review): the previous disambiguation tested
  # `issubclass(o, owner_types)` where `o` is itself an element of
  # `owner_types`. That test is always true, so the first owner's class was
  # always returned and its "too many owners" ValueError was unreachable.
  # That observable behavior is preserved; a stricter check (e.g. requiring
  # one owner type to subclass all others) would be a behavior change.
  first_owner = owners[0]
  return first_owner if tf_inspect.isclass(first_owner) else type(first_owner)
303
304
def getfutureimports(entity):
  """Detects what future imports are necessary to safely execute entity source.

  The entity's globals are scanned for values whose `__module__` is
  '__future__', i.e. the feature objects bound by `from __future__ import ...`.

  Args:
    entity: Any object

  Returns:
    A tuple of future strings, sorted by name. Empty if `entity` is neither a
    function nor a method.
  """
  if not (tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity)):
    return ()
  future_names = (
      name for name, value in entity.__globals__.items()
      if getattr(value, '__module__', None) == '__future__')
  return tuple(sorted(future_names))
319