• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Core conversion logic, serves as main point of access."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import inspect
23import sys
24import unittest
25
26from tensorflow.python.autograph.core import config
27from tensorflow.python.autograph.pyct import cache
28from tensorflow.python.autograph.pyct import inspect_utils
29from tensorflow.python.autograph.utils import ag_logging as logging
30from tensorflow.python.eager import function
31from tensorflow.python.util import tf_inspect
32
33
# Cache of allowlisting decisions, indexed first by entity and then by the
# conversion options. is_in_allowlist_cache reads it via
# `.has(entity, options)`; cache_allowlisted writes it via
# `[entity][options] = True`.
_ALLOWLIST_CACHE = cache.UnboundInstanceCache()
35
36
def _is_of_known_loaded_module(f, module_name):
  """Tells whether `f` is one of the top-level values of a loaded module.

  The module is looked up in `sys.modules` only, so it is never imported by
  this check. Membership is tested by identity against the values of the
  module's `__dict__`.

  Args:
    f: The entity to look for.
    module_name: Name of the module as registered in `sys.modules`.

  Returns:
    True if `f` is (identically) one of the module's namespace values;
    False if the module is not loaded, or if `f` is None.
  """
  module = sys.modules.get(module_name, None)
  if module is None:
    return False
  for value in module.__dict__.values():
    # `None` entries in the namespace never count as a match.
    if value is f and value is not None:
      return True
  return False
44
45
def _is_known_loaded_type(f, module_name, entity_name):
  """Tests whether the function or method is an instance of a known type.

  The type is `<module_name>.<entity_name>`, resolved only if the module is
  already present in `sys.modules`; nothing is imported by this check.

  Args:
    f: The function or method to test.
    module_name: Name of the module that defines the type.
    entity_name: Name of the type inside that module.

  Returns:
    True if `f`, or the function underlying a bound method `f`, is an
    instance of the type; False otherwise, including when the module is
    not loaded or lacks `entity_name`.
  """
  if module_name not in sys.modules:
    return False
  module = sys.modules[module_name]
  if not hasattr(module, entity_name):
    return False
  known_type = getattr(module, entity_name)

  # `f` itself is of this type. Example:
  #
  # o = ClassType()
  # function(o.method)()
  if isinstance(f, known_type):
    return True

  # `f` is a bound method whose underlying function is of this type.
  # Example:
  #
  # class ClassType:
  #   @function
  #   def method(self):
  #     ...
  # o = ClassType()
  # o.method()
  #
  # The stdlib `inspect` is used on purpose rather than tf_inspect, so that
  # tf.function decorators are not unpacked.
  return inspect.ismethod(f) and isinstance(f.__func__, known_type)
71
72
def is_unsupported(o):
  """Checks whether an entity is supported by AutoGraph at all.

  Entities for which this returns True are "permanently allowed": they are
  run as-is and never converted.

  Args:
    o: A Python entity, typically a callable.

  Returns:
    True if `o` is a wrapt-decorated callable, a `functools.lru_cache`
    wrapper, a constructor, a top-level value of one of a few stdlib
    modules, or a member of a TensorFlow plugin module; False otherwise.
  """

  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(o, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(o, 'wrapt', 'BoundFunctionWrapper')):
    logging.warn(
        '{} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will run as-is.'
        ' You may still apply AutoGraph before the wrapt decorator.'.format(o))
    logging.log(2, 'Permanently allowed: %s: wrapt decorated', o)
    return True

  if _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'):
    logging.log(2, 'Permanently allowed: %s: lru_cache', o)
    return True

  # Constructors are permanently allowed.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if inspect_utils.isconstructor(o):
    logging.log(2, 'Permanently allowed: %s: constructor', o)
    return True

  # Other built-in modules are permanently allowed.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  if any(
      _is_of_known_loaded_module(o, m)
      for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
    logging.log(2, 'Permanently allowed: %s: part of builtin module', o)
    return True

  # Custom ops and kernels are also permanently allowed.
  # See tensorflow.framework.load_library, which presumably sets the
  # `_IS_TENSORFLOW_PLUGIN` marker on the module object it registers in
  # `sys.modules` (confirm against load_library).
  # `o.__module__` is normally the module *name* (a str), so the marker has
  # to be looked up on the module object; testing the string itself can
  # never succeed. A non-string `__module__` is still checked directly.
  module_ref = getattr(o, '__module__', None)
  if isinstance(module_ref, str):
    module_ref = sys.modules.get(module_ref)
  if hasattr(module_ref, '_IS_TENSORFLOW_PLUGIN'):
    logging.log(2, 'Permanently allowed: %s: TensorFlow plugin', o)
    return True

  return False
113
114
# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
def is_allowlisted(
    o, check_call_override=True, allow_namedtuple_subclass=False):
  """Checks whether an entity is allowed for use in graph mode.

  Examples of allowed entities include all members of the tensorflow
  package. The checks run in this order, and the first one that reaches a
  decision wins:

    1. `config.CONVERSION_RULES`, applied to the module of `o`.
    2. Generator functions are allowed.
    3. Non-class callables are allowed if their `__call__` is allowed.
    4. Methods are allowed if their class is a `unittest.TestCase` subclass,
       or if the class defining them is allowed.
    5. Namedtuple types are allowed (see `allow_namedtuple_subclass`).

  Anything not allowed by one of these is not allowed.

  Args:
    o: A Python entity.
    check_call_override: Reserved for internal use. When set to `False`, it
      disables the rule according to which callable (non-class) objects are
      allowed if their __call__ method is allowed.
    allow_namedtuple_subclass: Reserved for internal use. Despite the name,
      when `True` subclasses of namedtuples are NOT allowed by the
      namedtuple rule; only namedtuple types none of whose bases is a
      namedtuple are. When `False`, namedtuples and their subclasses are
      allowed.

  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  # functools.partial objects have no __module__ attribute, so
  # tf_inspect.getmodule would otherwise return None for them.
  module = (functools if isinstance(o, functools.partial)
            else tf_inspect.getmodule(o))

  # Examples of callables that lack a __module__ property include builtins.
  if hasattr(module, '__name__'):
    for conversion_rule in config.CONVERSION_RULES:
      verdict = conversion_rule.get_action(module)
      if verdict == config.Action.CONVERT:
        logging.log(2, 'Not allowed: %s: %s', o, conversion_rule)
        return False
      if verdict == config.Action.DO_NOT_CONVERT:
        logging.log(2, 'Allowlisted: %s: %s', o, conversion_rule)
        return True

  # tf_inspect.isgeneratorfunction crashes on objects without __code__,
  # hence the guard.
  if hasattr(o, '__code__') and tf_inspect.isgeneratorfunction(o):
    logging.log(2, 'Allowlisted: %s: generator functions are not converted', o)
    return True

  # Callable objects: allowed if their __call__ method is. Comparing the
  # types avoids infinite recursion around the __call__ of function objects.
  if (check_call_override and
      not tf_inspect.isclass(o) and
      hasattr(o, '__call__') and
      type(o) != type(o.__call__) and  # pylint: disable=unidiomatic-typecheck
      is_allowlisted(o.__call__)):
    logging.log(2, 'Allowlisted: %s: object __call__ allowed', o)
    return True

  if tf_inspect.ismethod(o):
    # Methods of allowed classes are also allowed, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is allowed. Then `baz.bar` is also
    # allowed.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be allowed.
    method_owner = inspect_utils.getmethodclass(o)
    if method_owner is function.TfMethodTarget:
      method_owner = o.__self__.target_class
    if method_owner is not None:
      if issubclass(method_owner, unittest.TestCase):
        logging.log(2, 'Allowlisted: %s: method of TestCase subclass', o)
        return True

      method_owner = inspect_utils.getdefiningclass(o, method_owner)
      owner_is_allowed = is_allowlisted(
          method_owner,
          check_call_override=False,
          allow_namedtuple_subclass=True)
      if owner_is_allowed:
        logging.log(2, 'Allowlisted: %s: owner is allowed %s', o,
                    method_owner)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be
    # converted because they don't expose source code. They are assumed to
    # be safe for graph mode since they are just containers.
    if not allow_namedtuple_subclass:
      logging.log(2, 'Allowlisted: %s: named tuple or subclass', o)
      return True
    if not any(inspect_utils.isnamedtuple(base) for base in o.__bases__):
      logging.log(2, 'Allowlisted: %s: named tuple', o)
      return True

  logging.log(2, 'Not allowed: %s: default rule', o)
  return False
216
217
def is_in_allowlist_cache(entity, options):
  """Returns whether the allowlist cache has an entry for `entity`/`options`.

  Entities the cache cannot index are reported as not cached instead of
  raising.
  """
  try:
    found = _ALLOWLIST_CACHE.has(entity, options)
  except TypeError:
    # Catch-all for entities that are unhashable or don't allow weakrefs.
    return False
  return found
224
225
def cache_allowlisted(entity, options):
  """Records in the allowlist cache that `entity` is allowed for `options`.

  Best-effort: entities the cache cannot store are skipped silently.
  """
  try:
    entries_for_entity = _ALLOWLIST_CACHE[entity]
    entries_for_entity[options] = True
  except TypeError:
    # Catch-all for entities that are unhashable or don't allow weakrefs;
    # these are skipped rather than treated as errors.
    pass
232