1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Core conversion logic, serves as main point of access.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import inspect 23import sys 24import unittest 25 26from tensorflow.python.autograph.core import config 27from tensorflow.python.autograph.pyct import cache 28from tensorflow.python.autograph.pyct import inspect_utils 29from tensorflow.python.autograph.utils import ag_logging as logging 30from tensorflow.python.eager import function 31from tensorflow.python.util import tf_inspect 32 33 34_ALLOWLIST_CACHE = cache.UnboundInstanceCache() 35 36 37def _is_of_known_loaded_module(f, module_name): 38 mod = sys.modules.get(module_name, None) 39 if mod is None: 40 return False 41 if any(v is not None for v in mod.__dict__.values() if f is v): 42 return True 43 return False 44 45 46def _is_known_loaded_type(f, module_name, entity_name): 47 """Tests whether the function or method is an instance of a known type.""" 48 if (module_name not in sys.modules or 49 not hasattr(sys.modules[module_name], entity_name)): 50 return False 51 type_entity = getattr(sys.modules[module_name], entity_name) 52 if isinstance(f, type_entity): 53 # The method if of this type. Example: 54 # 55 # o = ClassType() 56 # function(o.method)() 57 return True 58 # Note: inspect is required here, to avoid unpacking tf.function decorators. 59 if inspect.ismethod(f): 60 # The unbound method if of this type. Example: 61 # 62 # class ClassType: 63 # @function 64 # def method(self): 65 # ... 66 # o = ClassType() 67 # o.method() 68 if isinstance(f.__func__, type_entity): 69 return True 70 return False 71 72 73def is_unsupported(o): 74 """Checks whether an entity is supported by AutoGraph at all.""" 75 76 # TODO(b/122265385): Remove this bypass. 77 if (_is_known_loaded_type(o, 'wrapt', 'FunctionWrapper') or 78 _is_known_loaded_type(o, 'wrapt', 'BoundFunctionWrapper')): 79 logging.warn( 80 '{} appears to be decorated by wrapt, which is not yet supported' 81 ' by AutoGraph. The function will run as-is.' 82 ' You may still apply AutoGraph before the wrapt decorator.'.format(o)) 83 logging.log(2, 'Permanently allowed: %s: wrapt decorated', o) 84 return True 85 86 if _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'): 87 logging.log(2, 'Permanently allowed: %s: lru_cache', o) 88 return True 89 90 # Constructors are permanently allowed. 91 # TODO(mdan): Toggle as experimental feature instead. 92 # TODO(b/124016764): Remove this limitation. 93 if inspect_utils.isconstructor(o): 94 logging.log(2, 'Permanently allowed: %s: constructor', o) 95 return True 96 97 # Other built-in modules are permanently allowed. 98 # TODO(mdan): Figure out how to do this consistently for all stdlib modules. 99 if any( 100 _is_of_known_loaded_module(o, m) 101 for m in ('collections', 'pdb', 'copy', 'inspect', 're')): 102 logging.log(2, 'Permanently allowed: %s: part of builtin module', o) 103 return True 104 105 # Custom ops and kernels are also permanently allowed. 106 # See tensorflow.framework.load_library. 107 if (hasattr(o, '__module__') and 108 hasattr(o.__module__, '_IS_TENSORFLOW_PLUGIN')): 109 logging.log(2, 'Permanently allowed: %s: TensorFlow plugin', o) 110 return True 111 112 return False 113 114 115# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True. 116def is_allowlisted( 117 o, check_call_override=True, allow_namedtuple_subclass=False): 118 """Checks whether an entity is allowed for use in graph mode. 119 120 Examples of allowed entities include all members of the tensorflow 121 package. 122 123 Args: 124 o: A Python entity. 125 check_call_override: Reserved for internal use. When set to `False`, it 126 disables the rule according to which classes are allowed if their 127 __call__ method is allowed. 128 allow_namedtuple_subclass: Reserved for internal use. When `True`, 129 namedtuple subclasses are not allowed. 130 131 Returns: 132 Boolean 133 """ 134 # TODO(b/120224672): Fix this. 135 if isinstance(o, functools.partial): 136 # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since 137 # functools.partial objects do not have a __module__ attribute. 138 m = functools 139 else: 140 m = tf_inspect.getmodule(o) 141 142 # Examples of callables that lack a __module__ property include builtins. 143 if hasattr(m, '__name__'): 144 for rule in config.CONVERSION_RULES: 145 action = rule.get_action(m) 146 if action == config.Action.CONVERT: 147 logging.log(2, 'Not allowed: %s: %s', o, rule) 148 return False 149 elif action == config.Action.DO_NOT_CONVERT: 150 logging.log(2, 'Allowlisted: %s: %s', o, rule) 151 return True 152 153 # The check for __code__ below is because isgeneratorfunction crashes 154 # without one. 155 if hasattr(o, '__code__') and tf_inspect.isgeneratorfunction(o): 156 logging.log(2, 'Allowlisted: %s: generator functions are not converted', o) 157 return True 158 159 if (check_call_override and not tf_inspect.isclass(o) and 160 hasattr(o, '__call__')): 161 # Callable objects: allowed if their __call__ method is. 162 # The type check avoids infinite recursion around the __call__ method 163 # of function objects. 164 if (type(o) != type(o.__call__)) and is_allowlisted(o.__call__): # pylint: disable=unidiomatic-typecheck 165 logging.log(2, 'Allowlisted: %s: object __call__ allowed', o) 166 return True 167 168 owner_class = None 169 if tf_inspect.ismethod(o): 170 # Methods of allowed classes are also allowed, even if they are 171 # bound via user subclasses. 172 # 173 # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is 174 # defined as below. `tf.Foo` is allowed. Then `baz.bar` is also 175 # allowed. 176 # 177 # class Custom(tf.Foo): 178 # pass 179 # 180 # baz = Custom() 181 # 182 # For the example above, if `Custom` did overload `bar`, then it would no 183 # longer be allowed. 184 185 owner_class = inspect_utils.getmethodclass(o) 186 if owner_class is function.TfMethodTarget: 187 owner_class = o.__self__.target_class 188 if owner_class is not None: 189 if issubclass(owner_class, unittest.TestCase): 190 logging.log(2, 'Allowlisted: %s: method of TestCase subclass', o) 191 return True 192 193 owner_class = inspect_utils.getdefiningclass(o, owner_class) 194 if is_allowlisted( 195 owner_class, 196 check_call_override=False, 197 allow_namedtuple_subclass=True): 198 logging.log(2, 'Allowlisted: %s: owner is allowed %s', o, 199 owner_class) 200 return True 201 202 if inspect_utils.isnamedtuple(o): 203 # Due to the way they're constructed, namedtuple types cannot be converted 204 # because they don't expose source code. But we assume they are safe for 205 # graph mode since they are just containers. 206 if allow_namedtuple_subclass: 207 if not any(inspect_utils.isnamedtuple(base) for base in o.__bases__): 208 logging.log(2, 'Allowlisted: %s: named tuple', o) 209 return True 210 else: 211 logging.log(2, 'Allowlisted: %s: named tuple or subclass', o) 212 return True 213 214 logging.log(2, 'Not allowed: %s: default rule', o) 215 return False 216 217 218def is_in_allowlist_cache(entity, options): 219 try: 220 return _ALLOWLIST_CACHE.has(entity, options) 221 except TypeError: 222 # Catch-all for entities that are unhashable or don't allow weakrefs. 223 return False 224 225 226def cache_allowlisted(entity, options): 227 try: 228 _ALLOWLIST_CACHE[entity][options] = True 229 except TypeError: 230 # Catch-all for entities that are unhashable or don't allow weakrefs. 231 pass 232