1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""This module contains the user-facing API for AutoGraph.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import copy 23import functools 24import os 25import pdb 26import sys 27 28from enum import Enum 29 30# pylint:disable=g-bad-import-order 31import numpy as np 32import six 33# pylint:enable=g-bad-import-order 34 35 36from tensorflow.python.autograph.core import converter 37from tensorflow.python.autograph.impl import conversion 38from tensorflow.python.autograph.operators import py_builtins 39from tensorflow.python.autograph.pyct import compiler 40from tensorflow.python.autograph.pyct import errors 41from tensorflow.python.autograph.pyct import inspect_utils 42from tensorflow.python.autograph.utils import ag_logging as logging 43from tensorflow.python.autograph.utils import py_func 44from tensorflow.python.framework import tensor_util 45from tensorflow.python.util import nest 46from tensorflow.python.util import tf_decorator 47from tensorflow.python.util import tf_inspect 48from tensorflow.python.util.tf_export import tf_export 49 50 51def is_autograph_strict_conversion_mode(): 52 return int(os.environ.get('AUTOGRAPH_STRICT_CONVERSION', '0')) > 0 53 54 55# TODO(mdan): Properly document the type hints. 56# TODO(mdan): Reduce the type hint information to (module, type). 57# (currently we require (module + class name, type)) 58 59 60# TODO(mdan): This should behave like to_graph (e.g. convert statically). 61# TODO(znado): Make an alias so can write Verbosity directly without needing 62# to write converter. 63def convert( 64 recursive=False, 65 optional_features=converter.Feature.ALL): 66 """Decorator that compiles a function to use TensorFlow ops. 67 68 The decorator is dynamic - it recompiles the target whenever the decorated 69 function is called. This means the parameter values are known at conversion. 70 It also means that repeated calls with different types of parameters will be 71 correctly processed. 72 73 Args: 74 recursive: bool, whether to recursively convert any functions or classes 75 that the converted function may use. 76 optional_features: converted.Feature, allows toggling optional or 77 experimental features. When set to None, only the core features are 78 enabled. 79 80 Returns: 81 Callable, a decorator that converts the given function into an equivalent 82 function that uses TensorFlow ops. 83 """ 84 85 def decorator(f): 86 """Decorator implementation.""" 87 88 @functools.wraps(f) 89 def wrapper(*args, **kwargs): 90 return converted_call( 91 f, None, 92 converter.ConversionOptions( 93 recursive=recursive, 94 force_conversion=True, 95 optional_features=optional_features, 96 ), args, kwargs) 97 98 wrapper = tf_decorator.make_decorator(f, wrapper) 99 100 # Sometimes the decorator is just desugared, making it impossible to detect. 101 # This attribute makes detection easier. 102 setattr(wrapper, '__ag_compiled', True) 103 return wrapper 104 105 return decorator 106 107 108class RunMode(Enum): 109 """Specifies the way a converted function or method should be executed in TF. 110 111 Attributes: 112 * GRAPH: Call this function directly, as-is. This is suitable for functions 113 that were already designed for TF graphs and contain ops. 114 * PY_FUNC: Wrap this function into a py_func op. This is suitable for code 115 that will only run correctly in Python, for example code that renders 116 to the display, reads keyboard input, etc. 117 """ 118 GRAPH = 1 119 PY_FUNC = 2 120 121 122def do_not_convert_internal(f): 123 """Decorator that marks internal functions which do not need conversion.""" 124 setattr(f, '__ag_compiled', True) 125 return f 126 127 128def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): 129 """Decorator that suppresses the conversion of a function. 130 131 See also: docs/pyfunc_dtypes.md 132 133 Args: 134 run_as: RunMode, specifies how to use the function in TensorFlow. 135 return_dtypes: Optional[Iterable[ Union[tf.DType, 136 utils.py_func.MatchDType]]], the return data types of the converted 137 function, if run_as is RunMode.PY_FUNC. Ignored otherwise. May be set to 138 None if the function has no return values. 139 140 Returns: 141 Callable, a decorator that wraps the original function. 142 """ 143 144 def decorator(f): 145 """Decorator implementation.""" 146 147 @functools.wraps(f) 148 def graph_wrapper(*args, **kwargs): 149 return f(*args, **kwargs) 150 151 @functools.wraps(f) 152 def py_func_wrapper(*args, **kwargs): 153 if kwargs: 154 raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs') 155 # TODO(mdan): Add support for kwargs. 156 return py_func.wrap_py_func( 157 f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes) 158 159 if run_as == RunMode.GRAPH: 160 wrapper = graph_wrapper 161 elif run_as == RunMode.PY_FUNC: 162 wrapper = py_func_wrapper 163 else: 164 raise ValueError('unknown value for run_as: %s' % run_as) 165 166 setattr(wrapper, '__ag_compiled', True) 167 return wrapper 168 169 return decorator 170 171 172def _call_unconverted(f, args, kwargs): 173 """Calls the original function without converting with AutoGraph.""" 174 if inspect_utils.istfmethodtarget(f): 175 return f.__self__.call(args, kwargs) 176 177 return f(*args, **kwargs) 178 179 180def _is_known_loaded_type(f, module_name, entity_name): 181 """Tests whether the function or method is an instance of a known type.""" 182 if (module_name not in sys.modules or 183 not hasattr(sys.modules[module_name], entity_name)): 184 return False 185 type_entity = getattr(sys.modules[module_name], entity_name) 186 if isinstance(f, type_entity): 187 # The method if of this type. Example: 188 # 189 # o = ClassType() 190 # function(o.method)() 191 return True 192 if tf_inspect.ismethod(f): 193 f = six.get_unbound_function(f) 194 # The the unbound method if of this type. Example: 195 # 196 # class ClassType: 197 # @function 198 # def method(self): 199 # ... 200 # o = ClassType() 201 # o.method() 202 if isinstance(f, type_entity): 203 return True 204 return False 205 206 207def converted_call(f, owner, options, args, kwargs): 208 """Compiles a function call inline. For internal use only.""" 209 logging.log(1, 210 'Converted call: %s; owner: %s\n args: %s\n kwargs: %s\n', 211 f, owner, args, kwargs) 212 213 if owner is not None: 214 if not isinstance(f, str): 215 raise ValueError( 216 'When owner is specified, the function name must be specified as' 217 ' a string: {}'.format(f)) 218 219 # Special case when the owner is a 'super' object. In that case lookups of 220 # dynamic attributes won't work. See 221 # inspect_utils.SuperWrapperForDynamicAttrs. 222 if isinstance(owner, super): 223 owner = inspect_utils.SuperWrapperForDynamicAttrs(owner) 224 225 f = getattr(owner, f) 226 227 if inspect_utils.isbuiltin(f): 228 return py_builtins.overload_of(f)(*args, **kwargs) 229 230 if _is_known_loaded_type(f, 'weakref', 'ref'): 231 logging.log(2, 'Permanently whitelisted: %s: weakref', f) 232 return _call_unconverted(f, args, kwargs) 233 234 # TODO(b/122265385): Remove this bypass. 235 if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or 236 _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')): 237 logging.warn( 238 'Entity {} appears to be decorated by wrapt, which is not yet supported' 239 ' by AutoGraph. The function will be called without transformation.' 240 ' You may however apply AutoGraph before the decorator.'.format(f)) 241 logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f) 242 return _call_unconverted(f, args, kwargs) 243 244 # Constructors are permanently whitelisted. 245 # TODO(mdan): Toggle as experimental feature instead. 246 # TODO(b/124016764): Remove this limitation. 247 if tf_inspect.isclass(f): 248 logging.log(2, 'Permanently whitelisted: %s: constructor', f) 249 return _call_unconverted(f, args, kwargs) 250 251 # Other built-in modules are permanently whitelisted. 252 # TODO(mdan): Figure out how to do this consistently for all stdlib modules. 253 # Note: TF linter disallows importing inspect. 254 if any(f in m.__dict__.values() 255 for m in (collections, pdb, copy, tf_inspect._inspect)): # pylint:disable=protected-access 256 logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f) 257 return _call_unconverted(f, args, kwargs) 258 259 if not options.force_conversion and conversion.is_whitelisted_for_graph(f): 260 return _call_unconverted(f, args, kwargs) 261 262 # internal_convert_user_code is for example turned off when issuing a dynamic 263 # call conversion from generated code while in nonrecursive mode. In that 264 # case we evidently don't want to recurse, but we still have to convert 265 # things like builtins. 266 if not options.internal_convert_user_code: 267 return _call_unconverted(f, args, kwargs) 268 269 # TODO(mdan): Move this entire block inside to_graph. 270 try: # Begin of transformation error guards 271 272 # Unwrap functools.partial objects 273 # TODO(mdan): Consider sharing unwrapping logic with tf_inspect. 274 while isinstance(f, functools.partial): 275 args = f.args + args 276 new_kwargs = {} 277 if f.keywords is not None: 278 new_kwargs.update(f.keywords) 279 new_kwargs.update(kwargs) 280 kwargs = new_kwargs 281 f = f.func 282 283 if tf_inspect.isfunction(f) or tf_inspect.ismethod(f): 284 # Regular functions 285 target_entity = f 286 arg_map_target = f 287 f_self = inspect_utils.getmethodself(f) 288 289 # TODO(b/119246461): This may be more elegantly handled using __get__? 290 if f_self is not None: 291 effective_args = (f_self,) + args 292 else: 293 effective_args = args 294 295 elif tf_inspect.isclass(f): 296 # Constructors 297 # Note: Until we support class constructurs, and enable whole-class 298 # conversion with an experimental flag, this branch is dead code. 299 # TODO(mdan): Consider removing unless there is a compelling use case. 300 target_entity = f 301 arg_map_target = f.__init__ 302 effective_args = args 303 304 elif hasattr(f, '__call__') and hasattr(f, '__class__'): 305 # Callable objects 306 target_entity = f.__call__ 307 arg_map_target = f.__call__ 308 effective_args = (f,) + args 309 310 else: 311 target_entity = f 312 raise NotImplementedError('unknown callable type "%s"' % type(f)) 313 314 arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs) 315 arg_types = {} 316 for name, arg in arg_values.items(): 317 arg_class = arg.__class__ 318 arg_types[name] = (arg_class.__name__, arg_class) 319 320 converted_f = to_graph( 321 target_entity, 322 recursive=options.recursive, 323 arg_values=arg_values, 324 arg_types=arg_types, 325 experimental_optional_features=options.optional_features) 326 327 if logging.has_verbosity(2): 328 logging.log(2, 'Defaults of %s : %s', converted_f, 329 converted_f.__defaults__) 330 callargs = tf_inspect.getcallargs(converted_f, *effective_args, **kwargs) 331 formatted_callargs = '\n'.join( 332 ' {}: {}'.format(k, v) for k, v in callargs.items()) 333 logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs) 334 335 # TODO(mdan): Reduce this list. 336 except (errors.AutoGraphError, AssertionError, AttributeError, IndexError, 337 KeyError, NameError, NotImplementedError, SyntaxError, TypeError, 338 ValueError, IOError) as e: 339 340 logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True) 341 342 if is_autograph_strict_conversion_mode(): 343 raise 344 345 logging.warn( 346 'Entity %s could not be transformed and will be staged without change.' 347 ' Error details can be found in the logs when running with the env' 348 ' variable AUTOGRAPH_VERBOSITY >= 1. Please report this to the' 349 ' AutoGraph team. Cause: %s', target_entity, e) 350 351 return _call_unconverted(f, args, kwargs) 352 353 result = converted_f(*effective_args, **kwargs) 354 355 # The converted function's closure is simply inserted into the function's 356 # module __dict__. Since modules are permanently cached, that results in 357 # leaking the entire closure. 358 # Normally, it's not safe to delete the module because that may release said 359 # closure as well. However, in the case of converted_call we are certain the 360 # function will not be executed again, so the closure should no longer be 361 # needed so long as the function doesn't return any executable code. 362 # TODO(mdan): Attach the closure properly, using cells. 363 if all(map(_is_not_callable, nest.flatten(result))): 364 del sys.modules[converted_f.__module__] 365 366 return result 367 368 369def _is_not_callable(obj): 370 # TODO(brianklee): Handle case when obj is a tensor dependent on a py_func. 371 if isinstance(obj, (int, float, complex, str, bool)): 372 return True 373 if isinstance(obj, (np.ndarray, np.generic)): 374 return True 375 if tensor_util.is_tensor(obj): 376 return True 377 return False 378 379 380@tf_export('autograph.to_graph') 381def to_graph(entity, 382 recursive=True, 383 arg_values=None, 384 arg_types=None, 385 experimental_optional_features=converter.Feature.ALL): 386 """Converts a Python entity into a TensorFlow graph. 387 388 Also see: `tf.autograph.to_code`, `tf.function`. 389 390 Unlike `tf.function`, `to_graph` is a low-level transpiler that converts 391 Python code to TensorFlow graph code. It does not implement any caching, 392 variable management or create any actual ops, and is best used where greater 393 control over the generated TensorFlow graph is desired. Another difference 394 from `tf.function` is that `to_graph` will not wrap the graph into a 395 TensorFlow function or a Python callable. Internally, `tf.function` uses 396 `to_graph`. 397 398 _Example Usage_ 399 400 ```python 401 def foo(x): 402 if x > 0: 403 y = x * x 404 else: 405 y = -x 406 return y 407 408 converted_foo = to_graph(foo) 409 410 x = tf.constant(1) 411 y = converted_foo(x) # converted_foo is a TensorFlow Op-like. 412 assert is_tensor(y) 413 ``` 414 415 Supported Python entities include: 416 * functions 417 * classes 418 * object methods 419 420 Functions are converted into new functions with converted code. 421 422 Classes are converted by generating a new class whose methods use converted 423 code. 424 425 Methods are converted into unbound function that have an additional first 426 argument called `self`. 427 428 Args: 429 entity: Python callable or class to convert. 430 recursive: Whether to recursively convert any functions that the 431 converted function may call. 432 arg_values: Optional dict of value hints for symbols including 433 function arguments mapping string names to actual values. For example, 434 `arg_values={'a': 1}` will map the variable `a` to the value `1`. 435 arg_types: Optional dict of type hints for symbols including function 436 arguments. Type hints allow specifying just the type of a variable, rather 437 than a specific value. 438 experimental_optional_features: `None`, a tuple of, or a single 439 `tf.autograph.experimental.Feature` value. Controls the use of 440 optional features in the conversion process. 441 442 Returns: 443 Same as `entity`, the converted Python function or class. 444 445 Raises: 446 ValueError: If the entity could not be converted. 447 """ 448 try: 449 program_ctx = converter.ProgramContext( 450 options=converter.ConversionOptions( 451 recursive=recursive, 452 optional_features=experimental_optional_features), 453 autograph_module=tf_inspect.getmodule(to_graph)) 454 nodes, name, namespace = conversion.entity_to_graph(entity, program_ctx, 455 arg_values, arg_types) 456 457 compiled_module, _ = compiler.ast_to_object( 458 nodes, 459 source_prefix=program_ctx.required_imports, 460 include_source_map=True) 461 462 # The compiled code should see everything the entry entity saw. 463 # TODO(mdan): This might not work well if the call tree spans modules? 464 for key, val in namespace.items(): 465 # Avoid overwriting entities that have been transformed. 466 if key not in compiled_module.__dict__: 467 compiled_module.__dict__[key] = val 468 compiled = getattr(compiled_module, name) 469 470 if hasattr(entity, '__defaults__'): 471 logging.log(3, 'Default args mapping: %s has: %s', entity, 472 entity.__defaults__) 473 compiled.__defaults__ = entity.__defaults__ 474 else: 475 logging.log(3, 'Default args mapping: %s has no __defaults__', entity) 476 477 logging.log(3, 'Namespace of %s includes: %s', compiled, 478 compiled_module.__dict__.keys()) 479 480 if hasattr(compiled, '__globals__'): 481 # Remove self to avoid circular references. This will probably only work 482 # so long as the function is not reentrant. 483 del compiled.__globals__[name] 484 485 # Need this so the source_mapping attribute is available for the context 486 # manager to access for runtime errors. 487 # 488 # Note that compiler.ast_to_object attaches the source map 'ag_source_map__' 489 # symbol to the compiled module. 490 # TODO(mdan): Record this statically in the generated code. 491 # TODO(mdan): Rename this attribute to 'autograph_info__' 492 source_map_attribute_name = 'ag_source_map' 493 if getattr(compiled, source_map_attribute_name, None) is not None: 494 # TODO(znado): change input problem errors into TransformError 495 raise ValueError('cannot convert %s because is has an attribute ' 496 '"%s", which is reserved for AutoGraph.' % 497 (compiled, source_map_attribute_name)) 498 setattr(compiled, source_map_attribute_name, 499 compiled_module.__dict__['ag_source_map__']) 500 501 return compiled 502 except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e: 503 errors.report_internal_error(entity, e) 504 505 506@tf_export('autograph.to_code') 507def to_code(entity, 508 recursive=True, 509 arg_values=None, 510 arg_types=None, 511 indentation=' ', 512 experimental_optional_features=converter.Feature.ALL): 513 """Similar to `to_graph`, but returns Python source code as a string. 514 515 Also see: `tf.autograph.to_graph`. 516 517 `to_graph` returns the Python source code that can be used to generate a 518 TensorFlow graph that is functionally identical to the input Python code. 519 520 Args: 521 entity: Python callable or class to convert. 522 recursive: Whether to recursively convert any functions that the 523 converted function may call. 524 arg_values: Optional dict of value hints for symbols including 525 function arguments mapping string names to actual values. For example, 526 `arg_values={'a': 1}` will map the variable `a` to the value `1`. 527 arg_types: Optional dict of type hints for symbols including function 528 arguments. Type hints allow specifying just the type of a variable, rather 529 than a specific value. 530 indentation: The string to use for indenting. Typically two or four spaces, 531 or just the tab character. 532 experimental_optional_features: `None`, a tuple of, or a single 533 `tf.autograph.experimental.Feature` value. Controls the use of 534 optional features in the conversion process. 535 536 Returns: 537 The converted code as string. 538 """ 539 program_ctx = converter.ProgramContext( 540 options=converter.ConversionOptions( 541 recursive=recursive, 542 optional_features=experimental_optional_features), 543 autograph_module=tf_inspect.getmodule(to_graph)) 544 nodes, _, _ = conversion.entity_to_graph(entity, program_ctx, arg_values, 545 arg_types) 546 547 code = compiler.ast_to_source(nodes, indentation) 548 549 return program_ctx.required_imports + '\n\n' + code 550