• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15# ==============================================================================
16"""A visitor class that generates protobufs for each python object."""
17
18import enum
19import sys
20
21from google.protobuf import message
22from tensorflow.python.platform import tf_logging as logging
23from tensorflow.python.util import deprecation
24from tensorflow.python.util import tf_decorator
25from tensorflow.python.util import tf_inspect
26from tensorflow.tools.api.lib import api_objects_pb2
27
# The following objects need to be handled individually. Keys are API paths
# relative to the visitor's default path. Mapping a path to an empty dict
# skips that object entirely. Otherwise its value maps member names to
# overrides: an empty override drops the member, and a non-empty one is passed
# as keyword arguments to `member.add()` instead of inspecting the member.
_CORNER_CASES = {
    '': {'tools': {}},
    'test.TestCase': {},
    'test.TestCase.failureException': {},
    'train.NanLossDuringTrainingError': {'message': {}},
    'estimator.NanLossDuringTrainingError': {'message': {}},
    'train.LooperThread': {'isAlive': {}, 'join': {}, 'native_id': {}},
}
47
# Python 2 vs. 3 differences. Under Python 3, type reprs such as
# "<class 'int'>" are mapped onto the Python 2 spelling "<type 'int'>".
if sys.version_info.major == 3:
  # Comprehensions keep the iteration variables out of the module namespace.
  _NORMALIZE_TYPE = {
      "<class '%s'>" % type_name: "<type '%s'>" % type_name
      for type_name in ('property', 'object', 'getset_descriptor', 'int',
                        'str', 'type', 'tuple', 'module',
                        'collections.defaultdict', 'set', 'dict', 'NoneType',
                        'frozenset', 'member_descriptor')
  }
  _NORMALIZE_TYPE.update({
      "<class '%s'>" % exc_name: "<type 'exceptions.%s'>" % exc_name
      for exc_name in ('Exception', 'RuntimeError')
  })
  _NORMALIZE_TYPE["<class 'abc.ABCMeta'>"] = "<type 'type'>"
  # Maps a class's string form to the spelling recorded in `is_instance`.
  _NORMALIZE_ISINSTANCE = {
      "<class "
      "'tensorflow.lite.python.op_hint.OpHint.OpHintArgumentTracker'>":  # pylint: disable=line-too-long
          "<class "
          "'tensorflow.lite.python.op_hint.OpHintArgumentTracker'>",
      "<class "
      "'tensorflow.python.training.monitored_session._MonitoredSession.StepContext'>":  # pylint: disable=line-too-long
          "<class "
          "'tensorflow.python.training.monitored_session.StepContext'>",
      "<class "
      "'tensorflow.python.ops.variables.Variable.SaveSliceInfo'>":
          "<class "
          "'tensorflow.python.ops.variables.SaveSliceInfo'>"
  }

  def _SkipMember(cls, member):
    """Return True if `member` of `cls` must be left out of the summary.

    `with_traceback` is skipped on every object; `name` and `value` are
    skipped only on `enum.Enum` subclasses.
    """
    return (member == 'with_traceback' or
            (member in ('name', 'value') and isinstance(cls, type) and
             issubclass(cls, enum.Enum)))
else:
  _NORMALIZE_TYPE = {
      "<class 'abc.ABCMeta'>":
          "<type 'type'>",
      "<class 'pybind11_type'>":
          "<class 'pybind11_builtins.pybind11_type'>",
  }
  _NORMALIZE_ISINSTANCE = {
      "<class 'pybind11_object'>":
          "<class 'pybind11_builtins.pybind11_object'>",
  }

  def _SkipMember(cls, member):  # pylint: disable=unused-argument
    """No members are skipped on this interpreter."""
    return False
90
91
# Differences created by typing implementations.
_NORMALIZE_TYPE.update({
    'tensorflow.python.framework.ops.Tensor':
        "<class 'tensorflow.python.framework.ops.Tensor'>",
    'typing.Generic': "<class 'typing.Generic'>",
    # TODO(b/203104448): Remove once the golden files are generated in
    # Python 3.7.
    "<class 'typing._GenericAlias'>": 'typing.Union',
    # TODO(b/203104448): Remove once the golden files are generated in
    # Python 3.9.
    "<class 'typing._UnionGenericAlias'>": 'typing.Union',
    # TODO(b/203104448): Remove once the golden files are generated in
    # Python 3.8.
    "<class 'typing_extensions._ProtocolMeta'>":
        "<class 'typing._ProtocolMeta'>",
    # TODO(b/203104448): Remove once the golden files are generated in
    # Python 3.8.
    "<class 'typing_extensions.Protocol'>": "<class 'typing.Protocol'>",
})

# On Python 3.8 and later, `_collections._tuplegetter` objects are reported
# with the same type string as properties.
if sys.version_info[0] == 3 and sys.version_info[1] >= 8:
  _NORMALIZE_TYPE["<class '_collections._tuplegetter'>"] = "<type 'property'>"
111
112
def _NormalizeType(ty):
  """Return the normalized spelling of the type string `ty`.

  Args:
    ty: string form of a type, e.g. `str(type(obj))`.

  Returns:
    The replacement registered in `_NORMALIZE_TYPE`, or `ty` unchanged when
    there is none.
  """
  try:
    return _NORMALIZE_TYPE[ty]
  except KeyError:
    return ty
115
116
def _NormalizeIsInstance(ty):
  """Return the normalized spelling of the base-class string `ty`.

  Args:
    ty: string form of a class, as produced for the `is_instance` list.

  Returns:
    The replacement registered in `_NORMALIZE_ISINSTANCE`, or `ty` unchanged
    when there is none.
  """
  if ty in _NORMALIZE_ISINSTANCE:
    return _NORMALIZE_ISINSTANCE[ty]
  return ty
119
120
def _SanitizedArgSpec(obj):
  """Get an ArgSpec string that is free of addresses.

  Some functions use objects (e.g. callables) as argument defaults whose
  string form embeds a memory address, such as "<Foo object at 0x7f...>".
  Every default containing " at 0x" is cut before its first " at " and given
  the suffix " instance>", so the result carries no addresses.

  Args:
    obj: A python routine to create the sanitized argspec of.

  Returns:
    string, a string representation of the argspec.
  """
  spec = tf_inspect.getargspec(obj)
  parts = ['%s=%s, ' % (field, getattr(spec, field))
           for field in ('args', 'varargs', 'keywords')]

  if not spec.defaults:
    parts.append('defaults=None')
  else:
    cleaned = []
    for default in spec.defaults:
      text = str(default)
      # Drop the address part of reprs like "<Foo object at 0x1234>".
      if ' at 0x' in text:
        text = '%s instance>' % text.partition(' at ')[0]
      cleaned.append(text)
    parts.append('defaults=%s, ' % cleaned)

  return ''.join(parts)
157
158
def _SanitizedMRO(obj):
  """Get a list of superclasses with minimal amount of non-TF classes.

  Based on many parameters like python version, OS, protobuf implementation
  or changes in google core libraries the list of superclasses of a class
  can change. Only the first class whose normalized name mentions neither
  "tensorflow" nor "keras" is kept, to be robust to non API affecting
  changes. The order returned by `tf_inspect.getmro` is preserved.

  Args:
    obj: The class whose superclasses are listed.

  Returns:
    list of strings, the normalized string representations of the classes.
  """
  sanitized = []
  # Classes named `_NewClass` come from the @deprecated_alias decorator.
  candidates = (klass for klass in tf_inspect.getmro(obj)
                if klass.__name__ != '_NewClass')
  for klass in candidates:
    normalized = _NormalizeType(str(klass))
    sanitized.append(normalized)
    # Keras classes (third_party/py/keras, keras_preprocessing) are tracked
    # like TF ones. StubOutForTesting may or may not derive from <object>
    # depending on the environment, so the walk always stops right after it.
    tracked = 'tensorflow' in normalized or 'keras' in normalized
    if not tracked or 'StubOutForTesting' in normalized:
      break
  return sanitized
194
195
def _IsProtoClass(obj):
  """Return True if `obj` is a Protocol Buffer message class (not instance)."""
  if not isinstance(obj, type):
    return False
  return issubclass(obj, message.Message)
199
200
class PythonObjectToProtoVisitor:
  """A visitor that summarizes given python objects as protobufs.

  Each call stores one `api_objects_pb2.TFAPIObject` describing a module,
  a protocol buffer class, or a regular class, keyed by its full API path.
  """

  def __init__(self, default_path='tensorflow'):
    # A dict to store all protocol buffers.
    # Keyed by "path" to the object.
    self._protos = {}
    self._default_path = default_path

  def GetProtos(self):
    """Return the dict of stored protos, keyed by full API path."""
    return self._protos

  def __call__(self, path, parent, children):
    """Build and store the proto summary of `parent`.

    Args:
      path: dotted path of `parent` below the default path; empty for the
        root object.
      parent: the object being visited; it is passed through
        `tf_decorator.unwrap` first.
      children: iterable of `(name, object)` pairs for the members of
        `parent`.
    """
    # The full path to the object.
    lib_path = self._default_path + '.' + path if path else self._default_path
    _, parent = tf_decorator.unwrap(parent)
    parent_corner_cases = _CORNER_CASES.get(path, {})

    # A path listed in _CORNER_CASES with an empty entry is not summarized.
    if path in _CORNER_CASES and not parent_corner_cases:
      return

    def _AddMember(member_name, member_obj, proto):
      """Add the child object to the object being constructed."""
      _, member_obj = tf_decorator.unwrap(member_obj)
      if (_SkipMember(parent, member_name) or
          isinstance(member_obj, deprecation.HiddenTfApiAttribute)):
        return
      # Names starting with '_' are left out, except for `__init__`.
      if member_name != '__init__' and member_name.startswith('_'):
        return
      if tf_inspect.isroutine(member_obj):
        new_method = proto.member_method.add()
        new_method.name = member_name
        # If member_obj is a python builtin, there is no way to get its
        # argspec, because it is implemented on the C side. It also has no
        # func_code.
        if hasattr(member_obj, '__code__'):
          new_method.argspec = _SanitizedArgSpec(member_obj)
      else:
        new_member = proto.member.add()
        new_member.name = member_name
        if tf_inspect.ismodule(member_obj):
          new_member.mtype = "<type 'module'>"
        else:
          new_member.mtype = _NormalizeType(str(type(member_obj)))

    def _AddChildren(proto):
      """Add every child to `proto`, applying per-member corner cases."""
      for name, child in children:
        if name not in parent_corner_cases:
          _AddMember(name, child, proto)
        elif parent_corner_cases[name]:
          # A non-empty corner case replaces the member; an empty one drops it.
          proto.member.add(**parent_corner_cases[name])

    if tf_inspect.ismodule(parent):
      module_obj = api_objects_pb2.TFAPIModule()
      _AddChildren(module_obj)
      self._protos[lib_path] = api_objects_pb2.TFAPIObject(
          path=lib_path, tf_module=module_obj)
    elif _IsProtoClass(parent):
      proto_obj = api_objects_pb2.TFAPIProto()
      parent.DESCRIPTOR.CopyToProto(proto_obj.descriptor)
      self._protos[lib_path] = api_objects_pb2.TFAPIObject(
          path=lib_path, tf_proto=proto_obj)
    elif tf_inspect.isclass(parent):
      class_obj = api_objects_pb2.TFAPIClass()
      class_obj.is_instance.extend(
          _NormalizeIsInstance(i) for i in _SanitizedMRO(parent))
      _AddChildren(class_obj)
      self._protos[lib_path] = api_objects_pb2.TFAPIObject(
          path=lib_path, tf_class=class_obj)
    else:
      logging.error('Illegal call to ApiProtoDump::_py_obj_to_proto. '
                    'Object is neither a module nor a class: %s', path)
287