1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""This module contains the user- and codegen-facing API for AutoGraph.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import imp 23import inspect 24import os 25import sys 26import textwrap 27import traceback 28 29import six 30 31from tensorflow.python.autograph import operators 32from tensorflow.python.autograph import utils 33from tensorflow.python.autograph.converters import asserts 34from tensorflow.python.autograph.converters import break_statements 35from tensorflow.python.autograph.converters import call_trees 36from tensorflow.python.autograph.converters import conditional_expressions 37from tensorflow.python.autograph.converters import continue_statements 38from tensorflow.python.autograph.converters import control_flow 39from tensorflow.python.autograph.converters import directives 40from tensorflow.python.autograph.converters import functions 41from tensorflow.python.autograph.converters import lists 42from tensorflow.python.autograph.converters import logical_expressions 43from tensorflow.python.autograph.converters import return_statements 44from tensorflow.python.autograph.converters import slices 45from tensorflow.python.autograph.converters import variables 46from tensorflow.python.autograph.core import ag_ctx 47from tensorflow.python.autograph.core import converter 48from tensorflow.python.autograph.core import function_wrappers 49from tensorflow.python.autograph.core import unsupported_features_checker 50from tensorflow.python.autograph.impl import conversion 51from tensorflow.python.autograph.lang import special_functions 52from tensorflow.python.autograph.operators import py_builtins 53from tensorflow.python.autograph.pyct import anno 54from tensorflow.python.autograph.pyct import cfg 55from tensorflow.python.autograph.pyct import error_utils 56from tensorflow.python.autograph.pyct import errors 57from tensorflow.python.autograph.pyct import inspect_utils 58from tensorflow.python.autograph.pyct import origin_info 59from tensorflow.python.autograph.pyct import qual_names 60from tensorflow.python.autograph.pyct import transpiler 61from tensorflow.python.autograph.pyct.static_analysis import activity 62from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions 63from tensorflow.python.autograph.utils import ag_logging as logging 64from tensorflow.python.eager import function 65from tensorflow.python.framework import errors_impl 66from tensorflow.python.util import tf_decorator 67from tensorflow.python.util import tf_inspect 68from tensorflow.python.util import tf_stack 69from tensorflow.python.util.tf_export import tf_export 70 71 72def is_autograph_strict_conversion_mode(): 73 return int(os.environ.get('AUTOGRAPH_STRICT_CONVERSION', '0')) > 0 74 75 76# 77# Error handling 78# 79 80 81# TODO(mdan): Export this symbol. 82class AutoGraphError(errors.PyCTError): 83 """Base class for all AutoGraph exceptions.""" 84 pass 85 86 87class ConversionError(AutoGraphError): 88 """Raised during the conversion process.""" 89 pass 90 91 92class StagingError(AutoGraphError): 93 """Raised during the staging (i.e. Python execution) of converted code.""" 94 pass 95 96 97class _ErrorMetadata(error_utils.ErrorMetadataBase): 98 """AutoGraph-specific error metadata. See base class.""" 99 100 def create_exception(self, source_error): 101 preferred_type = type(source_error) 102 if issubclass(preferred_type, errors_impl.OpError): 103 # Best-effort unpacking of OpError exceptions. 104 # TODO(mdan): Use a mechanism that is more future-proof. 105 init_argspec = tf_inspect.getfullargspec(preferred_type.__init__) 106 message = self.get_message() 107 init_args = tuple(init_argspec.args) 108 # At the time of this writing, TF errors either take 3 or 4 arguments, 109 # the argument '*args' may or may not be used. 110 if init_args == ('self', 'node_def', 'op', 'message'): 111 return preferred_type(source_error.node_def, source_error.op, message, 112 source_error.experimental_payloads) 113 114 elif preferred_type in (errors.PyCTError, AutoGraphError, ConversionError, 115 StagingError, errors_impl.InaccessibleTensorError, 116 errors_impl.OperatorNotAllowedInGraphError): 117 return preferred_type(self.get_message()) 118 119 exc = super(_ErrorMetadata, self).create_exception(source_error) 120 if exc is not None: 121 return exc 122 123 # Note: While changing an error's message property to change the message it 124 # displays will probably work a lot of times, there is no standard way in 125 # Python to do that. The safest way is therefore to create a new exception. 126 # For user defined exceptions, we could define an interface that allowed 127 # them to work under this mechanism. 128 return StagingError(self.get_message()) 129 130 131def _attach_error_metadata(e, f): 132 """Augments an error with the metadata necessary for rewrite.""" 133 if hasattr(e, 'ag_pass_through'): 134 return 135 136 metadata = getattr(e, 'ag_error_metadata', None) 137 source_map = f.ag_source_map 138 139 if metadata is None: 140 logging.log(1, 'Caught error in user callable %s', f, exc_info=True) 141 message = '{}: {}'.format(e.__class__.__name__, e) 142 else: 143 message = None 144 145 cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:] 146 147 e.ag_error_metadata = _ErrorMetadata(cause_tb, metadata, message, source_map, 148 __file__) 149 150 151class StackTraceMapper(tf_stack.StackTraceMapper): 152 """Remaps generated code to code it originated from.""" 153 154 def __init__(self, converted_fn): 155 super().__init__() 156 self._source_map = converted_fn.ag_source_map 157 # This may be called repeatedly: once on entry, by the superclass, then by 158 # each child context manager. 159 self._cached_map = None 160 161 def get_effective_source_map(self): 162 if self._cached_map is not None: 163 return self._cached_map 164 165 parent_map = self.parent.get_effective_source_map() 166 167 effective_source_map = {} 168 for loc, origin in self._source_map.items(): 169 effective_source_map[(loc.filename, loc.lineno)] = (origin.loc.filename, 170 origin.loc.lineno, 171 origin.function_name) 172 173 for key, value in parent_map.items(): 174 filename, lineno, _ = value 175 value_loc = origin_info.LineLocation(filename=filename, lineno=lineno) 176 if value_loc in self._source_map: 177 origin = self._source_map[value_loc] 178 effective_source_map[key] = (origin.loc.filename, origin.loc.lineno, 179 origin.function_name) 180 else: 181 effective_source_map[key] = value 182 183 self._cached_map = effective_source_map 184 return effective_source_map 185 186 187# 188# Actual source code transformation 189# 190 191 192class PyToTF(transpiler.PyToPy): 193 """The TensorFlow AutoGraph transformer.""" 194 195 def __init__(self): 196 super(PyToTF, self).__init__() 197 self._extra_locals = None 198 199 def get_transformed_name(self, node): 200 return 'tf__' + super(PyToTF, self).get_transformed_name(node) 201 202 def get_extra_locals(self): 203 if self._extra_locals is None: 204 # TODO(mdan): Move into core or replace with an actual importable module. 205 # Craft a module that exposes the external API as well as certain 206 # internal modules. 207 ag_internal = imp.new_module('autograph') 208 ag_internal.__dict__.update(inspect.getmodule(PyToTF).__dict__) 209 ag_internal.ConversionOptions = converter.ConversionOptions 210 ag_internal.STD = converter.STANDARD_OPTIONS 211 ag_internal.Feature = converter.Feature 212 ag_internal.utils = utils 213 ag_internal.FunctionScope = function_wrappers.FunctionScope 214 ag_internal.with_function_scope = function_wrappers.with_function_scope 215 # TODO(mdan): Add safeguards against name clashes. 216 # We don't want to create a submodule because we want the operators to be 217 # accessible as ag__.<operator> 218 ag_internal.__dict__.update(special_functions.__dict__) 219 ag_internal.__dict__.update(operators.__dict__) 220 221 self._extra_locals = {'ag__': ag_internal} 222 return self._extra_locals 223 224 def get_caching_key(self, ctx): 225 return ctx.options 226 227 def initial_analysis(self, node, ctx): 228 graphs = cfg.build(node) 229 node = qual_names.resolve(node) 230 node = activity.resolve(node, ctx, None) 231 node = reaching_definitions.resolve(node, ctx, graphs) 232 anno.dup( 233 node, 234 { 235 anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS, 236 }, 237 ) 238 return node 239 240 def transform_ast(self, node, ctx): 241 unsupported_features_checker.verify(node) 242 node = self.initial_analysis(node, ctx) 243 244 node = functions.transform(node, ctx) 245 node = directives.transform(node, ctx) 246 node = break_statements.transform(node, ctx) 247 if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS): 248 node = asserts.transform(node, ctx) 249 # Note: sequencing continue canonicalization before for loop one avoids 250 # dealing with the extra loop increment operation that the for 251 # canonicalization creates. 252 node = continue_statements.transform(node, ctx) 253 node = return_statements.transform(node, ctx) 254 if ctx.user.options.uses(converter.Feature.LISTS): 255 node = lists.transform(node, ctx) 256 node = slices.transform(node, ctx) 257 node = call_trees.transform(node, ctx) 258 node = control_flow.transform(node, ctx) 259 node = conditional_expressions.transform(node, ctx) 260 node = logical_expressions.transform(node, ctx) 261 node = variables.transform(node, ctx) 262 return node 263 264 265def _convert_actual(entity, program_ctx): 266 """Applies AutoGraph to entity.""" 267 268 # TODO(mdan): Put these extra fields inside __autograph_info__. 269 if not hasattr(entity, '__code__'): 270 raise ValueError('Cannot apply autograph to a function that doesn\'t ' 271 'expose a __code__ object. If this is a @tf.function,' 272 ' try passing f.python_function instead.') 273 274 transformed, module, source_map = _TRANSPILER.transform(entity, program_ctx) 275 276 assert not hasattr(transformed, 'ag_module') 277 assert not hasattr(transformed, 'ag_source_map') 278 transformed.ag_module = module 279 transformed.ag_source_map = source_map 280 return transformed 281 282 283# 284# Generated code support 285# 286 287 288def autograph_artifact(entity, extras=None): 289 if inspect.ismethod(entity): 290 setattr(entity.__func__, 'autograph_info__', extras) 291 else: 292 setattr(entity, 'autograph_info__', extras) 293 return entity 294 295 296def is_autograph_artifact(entity): 297 return hasattr(entity, 'autograph_info__') 298 299 300def converted_call(f, args, kwargs, caller_fn_scope=None, options=None): 301 """Converts a function call inline. 302 303 For internal use only. 304 305 Note: The argument list is optimized for readability of generated code, which 306 may look like this: 307 308 ag__.converted_call(f, (arg1, arg2), None, fscope) 309 ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope) 310 ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope) 311 312 Args: 313 f: The function to convert. 314 args: Tuple, the original positional arguments of f 315 kwargs: Optional[Dict], the original keyword arguments of f 316 caller_fn_scope: Optional[function_wrappers.FunctionScope], the function 317 scope of the converted function in which this call was originally made. 318 options: Optional[converter.ConversionOptions], conversion options. If not 319 specified, the value of caller_fn_scope.callopts is used. Either options 320 or caller_fn_scope must be present. 321 322 Returns: 323 Any, the result of executing a possibly-converted `f` with the given 324 arguments. 325 """ 326 logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args, 327 kwargs) 328 329 if options is None: 330 if caller_fn_scope is None: 331 raise ValueError('either caller_fn_scope or options must have a value') 332 options = caller_fn_scope.callopts 333 334 if conversion.is_in_allowlist_cache(f, options): 335 logging.log(2, 'Allowlisted %s: from cache', f) 336 return _call_unconverted(f, args, kwargs, options, False) 337 338 if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED: 339 logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f) 340 return _call_unconverted(f, args, kwargs, options, False) 341 342 if is_autograph_artifact(f): 343 logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f) 344 return _call_unconverted(f, args, kwargs, options) 345 346 # If this is a partial, unwrap it and redo all the checks. 347 if isinstance(f, functools.partial): 348 new_kwargs = {} 349 if f.keywords is not None: 350 # Use copy to avoid mutating the underlying keywords. 351 new_kwargs = f.keywords.copy() 352 if kwargs is not None: 353 new_kwargs.update(kwargs) 354 new_args = f.args + args 355 logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f, new_args, 356 new_kwargs) 357 return converted_call( 358 f.func, 359 new_args, 360 new_kwargs, 361 caller_fn_scope=caller_fn_scope, 362 options=options) 363 364 if inspect_utils.isbuiltin(f): 365 if f is eval: 366 return py_builtins.eval_in_original_context(f, args, caller_fn_scope) 367 if f is super: 368 return py_builtins.super_in_original_context(f, args, caller_fn_scope) 369 if f is globals: 370 return py_builtins.globals_in_original_context(caller_fn_scope) 371 if f is locals: 372 return py_builtins.locals_in_original_context(caller_fn_scope) 373 if kwargs: 374 return py_builtins.overload_of(f)(*args, **kwargs) 375 else: 376 return py_builtins.overload_of(f)(*args) 377 378 if conversion.is_unsupported(f): 379 return _call_unconverted(f, args, kwargs, options) 380 381 if not options.user_requested and conversion.is_allowlisted(f): 382 return _call_unconverted(f, args, kwargs, options) 383 384 # internal_convert_user_code is for example turned off when issuing a dynamic 385 # call conversion from generated code while in nonrecursive mode. In that 386 # case we evidently don't want to recurse, but we still have to convert 387 # things like builtins. 388 if not options.internal_convert_user_code: 389 return _call_unconverted(f, args, kwargs, options) 390 391 try: 392 if inspect.ismethod(f) or inspect.isfunction(f): 393 target_entity = f 394 effective_args = args 395 396 f_self = getattr(f, '__self__', None) 397 if f_self is not None: 398 if isinstance(f_self, function.TfMethodTarget): 399 f_self = f_self.target 400 effective_args = (f_self,) + effective_args 401 402 elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'): 403 # Callable objects. Dunder methods have special lookup rules, see: 404 # https://docs.python.org/3/reference/datamodel.html#specialnames 405 # TODO(mdan): Recurse into converted_call to simplify other verifications. 406 # This should be handled in the same way as partials. 407 target_entity = f.__class__.__call__ 408 effective_args = (f,) + args 409 410 else: 411 target_entity = f 412 raise NotImplementedError('unknown callable type "%s"' % type(f)) 413 414 except Exception as e: # pylint:disable=broad-except 415 logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True) 416 if is_autograph_strict_conversion_mode(): 417 raise 418 return _fall_back_unconverted(f, args, kwargs, options, e) 419 420 if not hasattr(target_entity, '__code__'): 421 logging.log(2, 'Permanently allowed: %s: native binding', target_entity) 422 return _call_unconverted(f, args, kwargs, options) 423 elif (hasattr(target_entity.__code__, 'co_filename') and 424 target_entity.__code__.co_filename == '<string>'): 425 # TODO(mdan): __globals__['txt'] might work in Py3. 426 logging.log(2, 'Permanently allowed: %s: dynamic code (exec?)', 427 target_entity) 428 return _call_unconverted(f, args, kwargs, options) 429 430 try: 431 program_ctx = converter.ProgramContext(options=options) 432 converted_f = _convert_actual(target_entity, program_ctx) 433 if logging.has_verbosity(2): 434 _log_callargs(converted_f, effective_args, kwargs) 435 except Exception as e: # pylint:disable=broad-except 436 logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True) 437 if is_autograph_strict_conversion_mode(): 438 raise 439 return _fall_back_unconverted(f, args, kwargs, options, e) 440 441 with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter(): 442 try: 443 if kwargs is not None: 444 result = converted_f(*effective_args, **kwargs) 445 else: 446 result = converted_f(*effective_args) 447 except Exception as e: 448 _attach_error_metadata(e, converted_f) 449 raise 450 451 return result 452 453 454def _call_unconverted(f, args, kwargs, options, update_cache=True): 455 """Calls the original function without converting with AutoGraph.""" 456 if update_cache: 457 conversion.cache_allowlisted(f, options) 458 459 if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget): 460 return f.__self__.call(args, kwargs) 461 462 if kwargs is not None: 463 return f(*args, **kwargs) 464 return f(*args) 465 466 467def _fall_back_unconverted(f, args, kwargs, options, exc): 468 """Falls back to calling the function unconverted, in case of error.""" 469 # TODO(mdan): Consider adding an internal metric. 470 warning_template = ( 471 'AutoGraph could not transform %s and will run it as-is.\n' 472 '%s' 473 'Cause: %s\n' 474 'To silence this warning, decorate the function with' 475 ' @tf.autograph.experimental.do_not_convert') 476 if isinstance(exc, errors.UnsupportedLanguageElementError): 477 if not conversion.is_in_allowlist_cache(f, options): 478 logging.warn(warning_template, f, '', exc) 479 else: 480 file_bug_message = ( 481 'Please report this to the TensorFlow team. When filing the bug, set' 482 ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and' 483 ' attach the full output.\n') 484 logging.warn(warning_template, f, file_bug_message, exc) 485 486 return _call_unconverted(f, args, kwargs, options) 487 488 489# 490# TensorFlow integration 491# 492 493 494@tf_export('__internal__.autograph.tf_convert', v1=[]) 495def tf_convert(f, ctx, convert_by_default=True, user_requested=False): 496 """Decorator that applies AutoGraph to a function. 497 498 Use in internal APIs. 499 500 This API is suitable for high order functions internal to the TensorFlow API, 501 and more generally any function to which AutoGraph is not applied. 502 503 Guidance: `convert` was a decorator meant for use directly by developers, but 504 most of today's uses go through `tf.function`. `tf_convert` is to be called 505 from high order functions internal to TF. By default, all the internal 506 TensorFlow functions are skipped when AutoGraph processes the code. This may 507 lead to user-supplied functions to be incorrectly skipped as well. 508 `tf_convert` helps avoid that. See the following example for more details. 509 510 ``` 511 =====tf_internal_module.py===== 512 513 def unconverted(input_fn): 514 return input_fn() 515 516 def converted(input_fn): 517 return tf.__internal__.autograph.tf_convert( 518 input_fn, ctx=tf.__internal__.autograph.control_status_ctx())() 519 520 ======user_module.py====== 521 522 @tf.function 523 def foo(input_fn) 524 return unconverted(input_fn) 525 526 @tf.function 527 def bar(input_fn) 528 return converted(input_fn) 529 530 @tf.function(autograph=False) 531 def baz(input_fn) 532 return converted(input_fn) 533 ``` 534 535 The `foo` method above will execute the `input_fn` without autograph 536 conversion, while the `bar` method will run an autographed `input_fn`. The 537 `baz` method will run an unconverted `input_fn`, since `tf_convert` respect 538 the control status context. 539 540 Note that both methods in `tf_internal_module` are skipped by autograph when 541 tracing the `tf.function`. The configuration of whether a module/package 542 should be skipped by autograph is controlled in 543 tensorflow/python/autograph/core/config.py. 544 545 Args: 546 f: Callable. 547 ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used. 548 convert_by_default: bool, whether to use AutoGraph when the context doesn't 549 specify. 550 user_requested: bool, whether to ignore the conversion allowlist. See 551 ConversionOptions.user_requested. 552 553 Returns: 554 Either `f or the converted version of `f`. 555 """ 556 557 if is_autograph_artifact(f): 558 return f 559 f_wrapper = f 560 decorators, f = tf_decorator.unwrap(f) 561 562 # TODO(mdan): Grab features from context. 563 # Note: we pass the original context through to convert to properly handle the 564 # following scenario, which can be used inside TF implementations: 565 # 566 # ctx = ag_ctx.control_status_ctx() 567 # @function(autograph=False) # Low-level graph code 568 # def inner_fn(): 569 # # The context is disabled here, but should be enabled in user user_fn 570 # tf_convert(user_fn, ctx=ctx) 571 if ctx.status == ag_ctx.Status.ENABLED: 572 wrapper_factory = convert( 573 recursive=True, user_requested=user_requested, conversion_ctx=ctx) 574 elif ctx.status == ag_ctx.Status.DISABLED: 575 wrapper_factory = do_not_convert 576 elif ctx.status == ag_ctx.Status.UNSPECIFIED: 577 if convert_by_default: 578 wrapper_factory = convert( 579 recursive=True, user_requested=user_requested, conversion_ctx=ctx) 580 else: 581 wrapper_factory = call_with_unspecified_conversion_status 582 else: 583 assert False, 'This switch contains all possible cases!' 584 wrapper = wrapper_factory(f) 585 586 if decorators: 587 wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper) 588 589 return autograph_artifact(wrapper) 590 591 592def call_with_unspecified_conversion_status(func): 593 """Decorator that resets the conversion context to the unspecified status.""" 594 595 def wrapper(*args, **kwargs): 596 with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED): 597 return func(*args, **kwargs) 598 599 if inspect.isfunction(func) or inspect.ismethod(func): 600 wrapper = functools.update_wrapper(wrapper, func) 601 602 return autograph_artifact(wrapper) 603 604 605def _log_callargs(f, args, kwargs): 606 """Logging helper.""" 607 logging.log(2, 'Defaults of %s : %s', f, f.__defaults__) 608 if not six.PY2: 609 logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__) 610 611 if kwargs is not None: 612 callargs = tf_inspect.getcallargs(f, *args, **kwargs) 613 else: 614 callargs = tf_inspect.getcallargs(f, *args) 615 616 formatted_callargs = '\n'.join( 617 ' {}: {}'.format(k, v) for k, v in callargs.items()) 618 logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs) 619 620 621# 622# Public API 623# 624 625 626@tf_export('autograph.experimental.do_not_convert') 627def do_not_convert(func=None): 628 """Decorator that suppresses the conversion of a function. 629 630 Args: 631 func: function to decorate. 632 633 Returns: 634 If `func` is not None, returns a `Callable` which is equivalent to 635 `func`, but is not converted by AutoGraph. 636 If `func` is None, returns a decorator that, when invoked with a 637 single `func` argument, returns a `Callable` equivalent to the 638 above case. 639 """ 640 if func is None: 641 return do_not_convert 642 643 def wrapper(*args, **kwargs): 644 with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED): 645 return func(*args, **kwargs) 646 647 if inspect.isfunction(func) or inspect.ismethod(func): 648 wrapper = functools.update_wrapper(wrapper, func) 649 650 return autograph_artifact(wrapper) 651 652 653# TODO(mdan): Make private. 654def convert(recursive=False, 655 optional_features=None, 656 user_requested=True, 657 conversion_ctx=ag_ctx.NullCtx()): 658 """Decorator that compiles a function to use TensorFlow ops. 659 660 The decorator is dynamic - it recompiles the target whenever the decorated 661 function is called. This means the parameter values are known at conversion. 662 It also means that repeated calls with different types of parameters will be 663 correctly processed. 664 665 Args: 666 recursive: bool, whether to recursively convert any functions or classes 667 that the converted function may use. 668 optional_features: converted.Feature, allows toggling optional or 669 experimental features. When set to None, only the core features are 670 enabled. 671 user_requested: bool, whether this is a function that the user explicitly 672 asked to be converted. See ConversionOptions.user_requested. 673 conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context in 674 which `f` is used. 675 676 Returns: 677 Callable, a decorator that converts the given function into an equivalent 678 function that uses TensorFlow ops. 679 """ 680 681 def decorator(f): 682 """Decorator implementation.""" 683 684 def wrapper(*args, **kwargs): 685 """Wrapper that calls the converted version of f.""" 686 options = converter.ConversionOptions( 687 recursive=recursive, 688 user_requested=user_requested, 689 optional_features=optional_features) 690 try: 691 with conversion_ctx: 692 return converted_call(f, args, kwargs, options=options) 693 except Exception as e: # pylint:disable=broad-except 694 if hasattr(e, 'ag_error_metadata'): 695 raise e.ag_error_metadata.to_exception(e) 696 else: 697 raise 698 699 if inspect.isfunction(f) or inspect.ismethod(f): 700 wrapper = functools.update_wrapper(wrapper, f) 701 702 decorated_wrapper = tf_decorator.make_decorator(f, wrapper) 703 return autograph_artifact(decorated_wrapper) 704 705 return decorator 706 707 708# pylint:disable=line-too-long 709@tf_export('autograph.to_graph', v1=[]) 710def to_graph(entity, recursive=True, experimental_optional_features=None): 711 """Converts a Python entity into a TensorFlow graph. 712 713 Also see: `tf.autograph.to_code`, `tf.function`. 714 715 Unlike `tf.function`, `to_graph` is a low-level transpiler that converts 716 Python code to TensorFlow graph code. It does not implement any caching, 717 variable management or create any actual ops, and is best used where greater 718 control over the generated TensorFlow graph is desired. Another difference 719 from `tf.function` is that `to_graph` will not wrap the graph into a 720 TensorFlow function or a Python callable. Internally, `tf.function` uses 721 `to_graph`. 722 723 Example usage: 724 725 >>> def f(x): 726 ... if x > 0: 727 ... y = x * x 728 ... else: 729 ... y = -x 730 ... return y 731 ... 732 >>> converted_f = to_graph(f) 733 >>> x = tf.constant(2) 734 >>> converted_f(x) # converted_foo is like a TensorFlow Op. 735 <tf.Tensor: shape=(), dtype=int32, numpy=4> 736 737 Supported Python entities include: 738 * functions 739 * classes 740 * object methods 741 742 Functions are converted into new functions with converted code. 743 744 Classes are converted by generating a new class whose methods use converted 745 code. 746 747 Methods are converted into unbound function that have an additional first 748 argument called `self`. 749 750 For a tutorial, see the 751 [tf.function and AutoGraph guide](https://www.tensorflow.org/guide/function). 752 For more detailed information, see the 753 [AutoGraph reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md). 754 755 Args: 756 entity: Python callable or class to convert. 757 recursive: Whether to recursively convert any functions that the converted 758 function may call. 759 experimental_optional_features: `None`, a tuple of, or a single 760 `tf.autograph.experimental.Feature` value. 761 762 Returns: 763 Same as `entity`, the converted Python function or class. 764 765 Raises: 766 ValueError: If the entity could not be converted. 767 """ 768 try: 769 program_ctx = converter.ProgramContext( 770 options=converter.ConversionOptions( 771 recursive=recursive, 772 user_requested=True, 773 optional_features=experimental_optional_features)) 774 return autograph_artifact(_convert_actual(entity, program_ctx)) 775 except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e: 776 logging.error(1, 'Error converting %s', entity, exc_info=True) 777 raise ConversionError('converting {}: {}: {}'.format( 778 entity, e.__class__.__name__, str(e))) 779 780 781@tf_export(v1=['autograph.to_graph']) 782def to_graph_v1(entity, 783 recursive=True, 784 arg_values=None, 785 arg_types=None, 786 experimental_optional_features=None): 787 """Converts a Python entity into a TensorFlow graph. 788 789 Also see: `tf.autograph.to_code`, `tf.function`. 790 791 Unlike `tf.function`, `to_graph` is a low-level transpiler that converts 792 Python code to TensorFlow graph code. It does not implement any caching, 793 variable management or create any actual ops, and is best used where greater 794 control over the generated TensorFlow graph is desired. Another difference 795 from `tf.function` is that `to_graph` will not wrap the graph into a 796 TensorFlow function or a Python callable. Internally, `tf.function` uses 797 `to_graph`. 798 799 _Example Usage_ 800 801 ```python 802 def foo(x): 803 if x > 0: 804 y = x * x 805 else: 806 y = -x 807 return y 808 809 converted_foo = to_graph(foo) 810 811 x = tf.constant(1) 812 y = converted_foo(x) # converted_foo is a TensorFlow Op-like. 813 assert is_tensor(y) 814 ``` 815 816 Supported Python entities include: 817 * functions 818 * classes 819 * object methods 820 821 Functions are converted into new functions with converted code. 822 823 Classes are converted by generating a new class whose methods use converted 824 code. 825 826 Methods are converted into unbound function that have an additional first 827 argument called `self`. 828 829 Args: 830 entity: Python callable or class to convert. 831 recursive: Whether to recursively convert any functions that the converted 832 function may call. 833 arg_values: Deprecated. 834 arg_types: Deprecated. 835 experimental_optional_features: `None`, a tuple of, or a single 836 `tf.autograph.experimental.Feature` value. 837 838 Returns: 839 Same as `entity`, the converted Python function or class. 840 841 Raises: 842 ValueError: If the entity could not be converted. 843 """ 844 del arg_types 845 del arg_values 846 return to_graph( 847 entity, 848 recursive=recursive, 849 experimental_optional_features=experimental_optional_features) 850 851 852@tf_export(v1=['autograph.to_code']) 853def to_code_v1(entity, 854 recursive=True, 855 arg_values=None, 856 arg_types=None, 857 indentation=' ', 858 experimental_optional_features=None): 859 """Returns the source code generated by AutoGraph, as a string. 860 861 Example usage: 862 863 >>> def f(x): 864 ... if x < 0: 865 ... x = -x 866 ... return x 867 >>> tf.autograph.to_code(f) 868 "...def tf__f(x):..." 869 870 Also see: `tf.autograph.to_graph`. 871 872 Note: If a function has been decorated with `tf.function`, pass its 873 underlying Python function, rather than the callable that `tf.function 874 creates: 875 876 >>> @tf.function 877 ... def f(x): 878 ... if x < 0: 879 ... x = -x 880 ... return x 881 >>> tf.autograph.to_code(f.python_function) 882 "...def tf__f(x):..." 883 884 Args: 885 entity: Python callable or class. 886 recursive: Whether to recursively convert any functions that the converted 887 function may call. 888 arg_values: Deprecated. 889 arg_types: Deprecated. 890 indentation: Deprecated. 891 experimental_optional_features: `None`, a tuple of, or a single 892 `tf.autograph.experimental.Feature` value. 893 894 Returns: 895 The converted code as string. 896 """ 897 del arg_values 898 del arg_types 899 del indentation 900 return to_code( 901 entity, 902 recursive=recursive, 903 experimental_optional_features=experimental_optional_features) 904 905 906@tf_export('autograph.to_code', v1=[]) 907def to_code(entity, recursive=True, experimental_optional_features=None): 908 """Returns the source code generated by AutoGraph, as a string. 909 910 Example usage: 911 912 >>> def f(x): 913 ... if x < 0: 914 ... x = -x 915 ... return x 916 >>> tf.autograph.to_code(f) 917 "...def tf__f(x):..." 918 919 Also see: `tf.autograph.to_graph`. 920 921 Note: If a function has been decorated with `tf.function`, pass its 922 underlying Python function, rather than the callable that `tf.function 923 creates: 924 925 >>> @tf.function 926 ... def f(x): 927 ... if x < 0: 928 ... x = -x 929 ... return x 930 >>> tf.autograph.to_code(f.python_function) 931 "...def tf__f(x):..." 932 933 Args: 934 entity: Python callable or class to convert. 935 recursive: Whether to recursively convert any functions that the converted 936 function may call. 937 experimental_optional_features: `None`, a tuple of, or a single 938 `tf.autograph.experimental.Feature` value. 939 940 Returns: 941 The converted code as string. 942 """ 943 source = tf_inspect.getsource( 944 to_graph( 945 entity, 946 recursive=recursive, 947 experimental_optional_features=experimental_optional_features)) 948 return textwrap.dedent(source) 949 950 951_TRANSPILER = PyToTF() 952