1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""This module contains the user- and codegen-facing API for AutoGraph.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import imp 23import inspect 24import os 25import sys 26import textwrap 27import traceback 28 29import six 30 31from tensorflow.python.autograph import operators 32from tensorflow.python.autograph import utils 33from tensorflow.python.autograph.converters import asserts 34from tensorflow.python.autograph.converters import break_statements 35from tensorflow.python.autograph.converters import call_trees 36from tensorflow.python.autograph.converters import conditional_expressions 37from tensorflow.python.autograph.converters import continue_statements 38from tensorflow.python.autograph.converters import control_flow 39from tensorflow.python.autograph.converters import directives 40from tensorflow.python.autograph.converters import functions 41from tensorflow.python.autograph.converters import lists 42from tensorflow.python.autograph.converters import logical_expressions 43from tensorflow.python.autograph.converters import return_statements 44from tensorflow.python.autograph.converters import slices 45from tensorflow.python.autograph.converters import variables 46from tensorflow.python.autograph.core import ag_ctx 47from tensorflow.python.autograph.core import converter 48from tensorflow.python.autograph.core import function_wrappers 49from tensorflow.python.autograph.core import unsupported_features_checker 50from tensorflow.python.autograph.impl import conversion 51from tensorflow.python.autograph.lang import special_functions 52from tensorflow.python.autograph.operators import py_builtins 53from tensorflow.python.autograph.pyct import anno 54from tensorflow.python.autograph.pyct import cfg 55from tensorflow.python.autograph.pyct import error_utils 56from tensorflow.python.autograph.pyct import errors 57from tensorflow.python.autograph.pyct import inspect_utils 58from tensorflow.python.autograph.pyct import origin_info 59from tensorflow.python.autograph.pyct import qual_names 60from tensorflow.python.autograph.pyct import transpiler 61from tensorflow.python.autograph.pyct.static_analysis import activity 62from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions 63from tensorflow.python.autograph.utils import ag_logging as logging 64from tensorflow.python.eager import function 65from tensorflow.python.framework import errors_impl 66from tensorflow.python.util import tf_decorator 67from tensorflow.python.util import tf_inspect 68from tensorflow.python.util import tf_stack 69from tensorflow.python.util.tf_export import tf_export 70 71 72def is_autograph_strict_conversion_mode(): 73 return int(os.environ.get('AUTOGRAPH_STRICT_CONVERSION', '0')) > 0 74 75 76# 77# Error handling 78# 79 80 81# TODO(mdan): Export this symbol. 82class AutoGraphError(errors.PyCTError): 83 """Base class for all AutoGraph exceptions.""" 84 pass 85 86 87class ConversionError(AutoGraphError): 88 """Raised during the conversion process.""" 89 pass 90 91 92class StagingError(AutoGraphError): 93 """Raised during the staging (i.e. Python execution) of converted code.""" 94 pass 95 96 97class _ErrorMetadata(error_utils.ErrorMetadataBase): 98 """AutoGraph-specific error metadata. See base class.""" 99 100 def create_exception(self, source_error): 101 preferred_type = type(source_error) 102 if issubclass(preferred_type, errors_impl.OpError): 103 # Best-effort unpacking of OpError exceptions. 104 # TODO(mdan): Use a mechanism that is more future-proof. 105 init_argspec = tf_inspect.getfullargspec(preferred_type.__init__) 106 message = self.get_message() 107 init_args = tuple(init_argspec.args) 108 # At the time of this writing, TF errors either take 3 or 4 arguments, 109 # with the fourth being error_code. 110 if init_args == ('self', 'node_def', 'op', 'message', 'error_code'): 111 return preferred_type( 112 node_def=source_error.node_def, 113 op=source_error.op, 114 message=message, 115 error_code=self.error_code) 116 elif init_args == ('self', 'node_def', 'op', 'message'): 117 if 'error_code' in init_argspec.kwonlyargs: 118 return preferred_type( 119 node_def=source_error.node_def, 120 op=source_error.op, 121 message=message, 122 errro_code=self.error_code) 123 else: 124 return preferred_type( 125 node_def=source_error.node_def, 126 op=source_error.op, 127 message=message) 128 129 elif preferred_type in (errors.PyCTError, AutoGraphError, ConversionError, 130 StagingError, errors_impl.InaccessibleTensorError, 131 errors_impl.OperatorNotAllowedInGraphError): 132 return preferred_type(self.get_message()) 133 134 exc = super(_ErrorMetadata, self).create_exception(source_error) 135 if exc is not None: 136 return exc 137 138 # Note: While changing an error's message property to change the message it 139 # displays will probably work a lot of times, there is no standard way in 140 # Python to do that. The safest way is therefore to create a new exception. 141 # For user defined exceptions, we could define an interface that allowed 142 # them to work under this mechanism. 143 return StagingError(self.get_message()) 144 145 146def _attach_error_metadata(e, f): 147 """Augments an error with the metadata necessary for rewrite.""" 148 if hasattr(e, 'ag_pass_through'): 149 return 150 151 metadata = getattr(e, 'ag_error_metadata', None) 152 source_map = f.ag_source_map 153 154 if metadata is None: 155 logging.log(1, 'Caught error in user callable %s', f, exc_info=True) 156 message = '{}: {}'.format(e.__class__.__name__, e) 157 else: 158 message = None 159 160 cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:] 161 162 e.ag_error_metadata = _ErrorMetadata( 163 cause_tb, metadata, message, source_map, __file__) 164 165 166class StackTraceMapper(tf_stack.StackTraceMapper): 167 """Remaps generated code to code it originated from.""" 168 169 def __init__(self, converted_fn): 170 super().__init__() 171 self._source_map = converted_fn.ag_source_map 172 # This may be called repeatedly: once on entry, by the superclass, then by 173 # each child context manager. 174 self._cached_map = None 175 176 def get_effective_source_map(self): 177 if self._cached_map is not None: 178 return self._cached_map 179 180 parent_map = self.parent.get_effective_source_map() 181 182 effective_source_map = {} 183 for loc, origin in self._source_map.items(): 184 effective_source_map[(loc.filename, loc.lineno)] = (origin.loc.filename, 185 origin.loc.lineno, 186 origin.function_name) 187 188 for key, value in parent_map.items(): 189 filename, lineno, _ = value 190 value_loc = origin_info.LineLocation(filename=filename, lineno=lineno) 191 if value_loc in self._source_map: 192 origin = self._source_map[value_loc] 193 effective_source_map[key] = (origin.loc.filename, origin.loc.lineno, 194 origin.function_name) 195 else: 196 effective_source_map[key] = value 197 198 self._cached_map = effective_source_map 199 return effective_source_map 200 201 202# 203# Actual source code transformation 204# 205 206 207class PyToTF(transpiler.PyToPy): 208 """The TensorFlow AutoGraph transformer.""" 209 210 def __init__(self): 211 super(PyToTF, self).__init__() 212 self._extra_locals = None 213 214 def get_transformed_name(self, node): 215 return 'tf__' + super(PyToTF, self).get_transformed_name(node) 216 217 def get_extra_locals(self): 218 if self._extra_locals is None: 219 # TODO(mdan): Move into core or replace with an actual importable module. 220 # Craft a module that exposes the external API as well as certain 221 # internal modules. 222 ag_internal = imp.new_module('autograph') 223 ag_internal.__dict__.update(inspect.getmodule(PyToTF).__dict__) 224 ag_internal.ConversionOptions = converter.ConversionOptions 225 ag_internal.STD = converter.STANDARD_OPTIONS 226 ag_internal.Feature = converter.Feature 227 ag_internal.utils = utils 228 ag_internal.FunctionScope = function_wrappers.FunctionScope 229 ag_internal.with_function_scope = function_wrappers.with_function_scope 230 # TODO(mdan): Add safeguards against name clashes. 231 # We don't want to create a submodule because we want the operators to be 232 # accessible as ag__.<operator> 233 ag_internal.__dict__.update(special_functions.__dict__) 234 ag_internal.__dict__.update(operators.__dict__) 235 236 self._extra_locals = {'ag__': ag_internal} 237 return self._extra_locals 238 239 def get_caching_key(self, ctx): 240 return ctx.options 241 242 def initial_analysis(self, node, ctx): 243 graphs = cfg.build(node) 244 node = qual_names.resolve(node) 245 node = activity.resolve(node, ctx, None) 246 node = reaching_definitions.resolve(node, ctx, graphs) 247 anno.dup( 248 node, 249 { 250 anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS, 251 }, 252 ) 253 return node 254 255 def transform_ast(self, node, ctx): 256 unsupported_features_checker.verify(node) 257 node = self.initial_analysis(node, ctx) 258 259 node = functions.transform(node, ctx) 260 node = directives.transform(node, ctx) 261 node = break_statements.transform(node, ctx) 262 if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS): 263 node = asserts.transform(node, ctx) 264 # Note: sequencing continue canonicalization before for loop one avoids 265 # dealing with the extra loop increment operation that the for 266 # canonicalization creates. 267 node = continue_statements.transform(node, ctx) 268 node = return_statements.transform(node, ctx) 269 if ctx.user.options.uses(converter.Feature.LISTS): 270 node = lists.transform(node, ctx) 271 node = slices.transform(node, ctx) 272 node = call_trees.transform(node, ctx) 273 node = control_flow.transform(node, ctx) 274 node = conditional_expressions.transform(node, ctx) 275 node = logical_expressions.transform(node, ctx) 276 node = variables.transform(node, ctx) 277 return node 278 279 280def _convert_actual(entity, program_ctx): 281 """Applies AutoGraph to entity.""" 282 283 # TODO(mdan): Put these extra fields inside __autograph_info__. 284 if not hasattr(entity, '__code__'): 285 raise ValueError('Cannot apply autograph to a function that doesn\'t ' 286 'expose a __code__ object. If this is a @tf.function,' 287 ' try passing f.python_function instead.') 288 289 transformed, module, source_map = _TRANSPILER.transform(entity, program_ctx) 290 291 assert not hasattr(transformed, 'ag_module') 292 assert not hasattr(transformed, 'ag_source_map') 293 transformed.ag_module = module 294 transformed.ag_source_map = source_map 295 return transformed 296 297 298# 299# Generated code support 300# 301 302 303def autograph_artifact(entity, extras=None): 304 if inspect.ismethod(entity): 305 setattr(entity.__func__, 'autograph_info__', extras) 306 else: 307 setattr(entity, 'autograph_info__', extras) 308 return entity 309 310 311def is_autograph_artifact(entity): 312 return hasattr(entity, 'autograph_info__') 313 314 315def converted_call(f, 316 args, 317 kwargs, 318 caller_fn_scope=None, 319 options=None): 320 """Converts a function call inline. 321 322 For internal use only. 323 324 Note: The argument list is optimized for readability of generated code, which 325 may look like this: 326 327 ag__.converted_call(f, (arg1, arg2), None, fscope) 328 ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope) 329 ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope) 330 331 Args: 332 f: The function to convert. 333 args: Tuple, the original positional arguments of f 334 kwargs: Optional[Dict], the original keyword arguments of f 335 caller_fn_scope: Optional[function_wrappers.FunctionScope], the function 336 scope of the converted function in which this call was originally made. 337 options: Optional[converter.ConversionOptions], conversion options. If not 338 specified, the value of caller_fn_scope.callopts is used. Either options 339 or caller_fn_scope must be present. 340 341 Returns: 342 Any, the result of executing a possibly-converted `f` with the given 343 arguments. 344 """ 345 logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args, 346 kwargs) 347 348 if options is None: 349 if caller_fn_scope is None: 350 raise ValueError('either caller_fn_scope or options must have a value') 351 options = caller_fn_scope.callopts 352 353 if conversion.is_in_allowlist_cache(f, options): 354 logging.log(2, 'Allowlisted %s: from cache', f) 355 return _call_unconverted(f, args, kwargs, options, False) 356 357 if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED: 358 logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f) 359 return _call_unconverted(f, args, kwargs, options, False) 360 361 if is_autograph_artifact(f): 362 logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f) 363 return _call_unconverted(f, args, kwargs, options) 364 365 # If this is a partial, unwrap it and redo all the checks. 366 if isinstance(f, functools.partial): 367 new_kwargs = {} 368 if f.keywords is not None: 369 # Use copy to avoid mutating the underlying keywords. 370 new_kwargs = f.keywords.copy() 371 if kwargs is not None: 372 new_kwargs.update(kwargs) 373 new_args = f.args + args 374 logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f, new_args, 375 new_kwargs) 376 return converted_call( 377 f.func, 378 new_args, 379 new_kwargs, 380 caller_fn_scope=caller_fn_scope, 381 options=options) 382 383 if inspect_utils.isbuiltin(f): 384 if f is eval: 385 return py_builtins.eval_in_original_context(f, args, caller_fn_scope) 386 if f is super: 387 return py_builtins.super_in_original_context(f, args, caller_fn_scope) 388 if f is globals: 389 return py_builtins.globals_in_original_context(caller_fn_scope) 390 if f is locals: 391 return py_builtins.locals_in_original_context(caller_fn_scope) 392 if kwargs: 393 return py_builtins.overload_of(f)(*args, **kwargs) 394 else: 395 return py_builtins.overload_of(f)(*args) 396 397 if conversion.is_unsupported(f): 398 return _call_unconverted(f, args, kwargs, options) 399 400 if not options.user_requested and conversion.is_allowlisted(f): 401 return _call_unconverted(f, args, kwargs, options) 402 403 # internal_convert_user_code is for example turned off when issuing a dynamic 404 # call conversion from generated code while in nonrecursive mode. In that 405 # case we evidently don't want to recurse, but we still have to convert 406 # things like builtins. 407 if not options.internal_convert_user_code: 408 return _call_unconverted(f, args, kwargs, options) 409 410 try: 411 if inspect.ismethod(f) or inspect.isfunction(f): 412 target_entity = f 413 effective_args = args 414 415 f_self = getattr(f, '__self__', None) 416 if f_self is not None: 417 if isinstance(f_self, function.TfMethodTarget): 418 f_self = f_self.target 419 effective_args = (f_self,) + effective_args 420 421 elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'): 422 # Callable objects. Dunder methods have special lookup rules, see: 423 # https://docs.python.org/3/reference/datamodel.html#specialnames 424 # TODO(mdan): Recurse into converted_call to simplify other verifications. 425 # This should be handled in the same way as partials. 426 target_entity = f.__class__.__call__ 427 effective_args = (f,) + args 428 429 else: 430 target_entity = f 431 raise NotImplementedError('unknown callable type "%s"' % type(f)) 432 433 except Exception as e: # pylint:disable=broad-except 434 logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True) 435 if is_autograph_strict_conversion_mode(): 436 raise 437 return _fall_back_unconverted(f, args, kwargs, options, e) 438 439 if not hasattr(target_entity, '__code__'): 440 logging.log(2, 'Permanently allowed: %s: native binding', 441 target_entity) 442 return _call_unconverted(f, args, kwargs, options) 443 elif (hasattr(target_entity.__code__, 'co_filename') and 444 target_entity.__code__.co_filename == '<string>'): 445 # TODO(mdan): __globals__['txt'] might work in Py3. 446 logging.log(2, 'Permanently allowed: %s: dynamic code (exec?)', 447 target_entity) 448 return _call_unconverted(f, args, kwargs, options) 449 450 try: 451 program_ctx = converter.ProgramContext(options=options) 452 converted_f = _convert_actual(target_entity, program_ctx) 453 if logging.has_verbosity(2): 454 _log_callargs(converted_f, effective_args, kwargs) 455 except Exception as e: # pylint:disable=broad-except 456 logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True) 457 if is_autograph_strict_conversion_mode(): 458 raise 459 return _fall_back_unconverted(f, args, kwargs, options, e) 460 461 with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter(): 462 try: 463 if kwargs is not None: 464 result = converted_f(*effective_args, **kwargs) 465 else: 466 result = converted_f(*effective_args) 467 except Exception as e: 468 _attach_error_metadata(e, converted_f) 469 raise 470 471 return result 472 473 474def _call_unconverted(f, args, kwargs, options, update_cache=True): 475 """Calls the original function without converting with AutoGraph.""" 476 if update_cache: 477 conversion.cache_allowlisted(f, options) 478 479 if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget): 480 return f.__self__.call(args, kwargs) 481 482 if kwargs is not None: 483 return f(*args, **kwargs) 484 return f(*args) 485 486 487def _fall_back_unconverted(f, args, kwargs, options, exc): 488 """Falls back to calling the function unconverted, in case of error.""" 489 # TODO(mdan): Consider adding an internal metric. 490 warning_template = ( 491 'AutoGraph could not transform %s and will run it as-is.\n' 492 '%s' 493 'Cause: %s\n' 494 'To silence this warning, decorate the function with' 495 ' @tf.autograph.experimental.do_not_convert') 496 if isinstance(exc, errors.UnsupportedLanguageElementError): 497 if not conversion.is_in_allowlist_cache(f, options): 498 logging.warn(warning_template, f, '', exc) 499 else: 500 file_bug_message = ( 501 'Please report this to the TensorFlow team. When filing the bug, set' 502 ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and' 503 ' attach the full output.\n') 504 logging.warn(warning_template, f, file_bug_message, exc) 505 506 return _call_unconverted(f, args, kwargs, options) 507 508 509# 510# TensorFlow integration 511# 512 513 514@tf_export('__internal__.autograph.tf_convert', v1=[]) 515def tf_convert(f, ctx, convert_by_default=True, user_requested=False): 516 """Decorator that applies AutoGraph to a function. 517 518 Use in internal APIs. 519 520 This API is suitable for high order functions internal to the TensorFlow API, 521 and more generally any function to which AutoGraph is not applied. 522 523 Guidance: `convert` was a decorator meant for use directly by developers, but 524 most of today's uses go through `tf.function`. `tf_convert` is to be called 525 from high order functions internal to TF. By default, all the internal 526 TensorFlow functions are skipped when AutoGraph processes the code. This may 527 lead to user-supplied functions to be incorrectly skipped as well. 528 `tf_convert` helps avoid that. See the following example for more details. 529 530 ``` 531 =====tf_internal_module.py===== 532 533 def unconverted(input_fn): 534 return input_fn() 535 536 def converted(input_fn): 537 return tf.__internal__.autograph.tf_convert( 538 input_fn, ctx=tf.__internal__.autograph.control_status_ctx())() 539 540 ======user_module.py====== 541 542 @tf.function 543 def foo(input_fn) 544 return unconverted(input_fn) 545 546 @tf.function 547 def bar(input_fn) 548 return converted(input_fn) 549 550 @tf.function(autograph=False) 551 def baz(input_fn) 552 return converted(input_fn) 553 ``` 554 555 The `foo` method above will execute the `input_fn` without autograph 556 conversion, while the `bar` method will run an autographed `input_fn`. The 557 `baz` method will run an unconverted `input_fn`, since `tf_convert` respect 558 the control status context. 559 560 Note that both methods in `tf_internal_module` are skipped by autograph when 561 tracing the `tf.function`. The configuration of whether a module/package 562 should be skipped by autograph is controlled in 563 tensorflow/python/autograph/core/config.py. 564 565 Args: 566 f: Callable. 567 ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used. 568 convert_by_default: bool, whether to use AutoGraph when the context doesn't 569 specify. 570 user_requested: bool, whether to ignore the conversion allowlist. See 571 ConversionOptions.user_requested. 572 573 Returns: 574 Either `f or the converted version of `f`. 575 """ 576 577 if is_autograph_artifact(f): 578 return f 579 f_wrapper = f 580 decorators, f = tf_decorator.unwrap(f) 581 582 # TODO(mdan): Grab features from context. 583 # Note: we pass the original context through to convert to properly handle the 584 # following scenario, which can be used inside TF implementations: 585 # 586 # ctx = ag_ctx.control_status_ctx() 587 # @function(autograph=False) # Low-level graph code 588 # def inner_fn(): 589 # # The context is disabled here, but should be enabled in user user_fn 590 # tf_convert(user_fn, ctx=ctx) 591 if ctx.status == ag_ctx.Status.ENABLED: 592 wrapper_factory = convert( 593 recursive=True, user_requested=user_requested, conversion_ctx=ctx) 594 elif ctx.status == ag_ctx.Status.DISABLED: 595 wrapper_factory = do_not_convert 596 elif ctx.status == ag_ctx.Status.UNSPECIFIED: 597 if convert_by_default: 598 wrapper_factory = convert( 599 recursive=True, user_requested=user_requested, conversion_ctx=ctx) 600 else: 601 wrapper_factory = call_with_unspecified_conversion_status 602 else: 603 assert False, 'This switch contains all possible cases!' 604 wrapper = wrapper_factory(f) 605 606 if decorators: 607 wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper) 608 609 return autograph_artifact(wrapper) 610 611 612def call_with_unspecified_conversion_status(func): 613 """Decorator that resets the conversion context to the unspecified status.""" 614 def wrapper(*args, **kwargs): 615 with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED): 616 return func(*args, **kwargs) 617 618 if inspect.isfunction(func) or inspect.ismethod(func): 619 wrapper = functools.update_wrapper(wrapper, func) 620 621 return autograph_artifact(wrapper) 622 623 624def _log_callargs(f, args, kwargs): 625 """Logging helper.""" 626 logging.log(2, 'Defaults of %s : %s', f, f.__defaults__) 627 if not six.PY2: 628 logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__) 629 630 if kwargs is not None: 631 callargs = tf_inspect.getcallargs(f, *args, **kwargs) 632 else: 633 callargs = tf_inspect.getcallargs(f, *args) 634 635 formatted_callargs = '\n'.join( 636 ' {}: {}'.format(k, v) for k, v in callargs.items()) 637 logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs) 638 639 640# 641# Public API 642# 643 644 645@tf_export('autograph.experimental.do_not_convert') 646def do_not_convert(func=None): 647 """Decorator that suppresses the conversion of a function. 648 649 Args: 650 func: function to decorate. 651 652 Returns: 653 If `func` is not None, returns a `Callable` which is equivalent to 654 `func`, but is not converted by AutoGraph. 655 If `func` is None, returns a decorator that, when invoked with a 656 single `func` argument, returns a `Callable` equivalent to the 657 above case. 658 """ 659 if func is None: 660 return do_not_convert 661 662 def wrapper(*args, **kwargs): 663 with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED): 664 return func(*args, **kwargs) 665 666 if inspect.isfunction(func) or inspect.ismethod(func): 667 wrapper = functools.update_wrapper(wrapper, func) 668 669 return autograph_artifact(wrapper) 670 671 672# TODO(mdan): Make private. 673def convert(recursive=False, 674 optional_features=None, 675 user_requested=True, 676 conversion_ctx=ag_ctx.NullCtx()): 677 """Decorator that compiles a function to use TensorFlow ops. 678 679 The decorator is dynamic - it recompiles the target whenever the decorated 680 function is called. This means the parameter values are known at conversion. 681 It also means that repeated calls with different types of parameters will be 682 correctly processed. 683 684 Args: 685 recursive: bool, whether to recursively convert any functions or classes 686 that the converted function may use. 687 optional_features: converted.Feature, allows toggling optional or 688 experimental features. When set to None, only the core features are 689 enabled. 690 user_requested: bool, whether this is a function that the user explicitly 691 asked to be converted. See ConversionOptions.user_requested. 692 conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context in 693 which `f` is used. 694 695 Returns: 696 Callable, a decorator that converts the given function into an equivalent 697 function that uses TensorFlow ops. 698 """ 699 700 def decorator(f): 701 """Decorator implementation.""" 702 703 def wrapper(*args, **kwargs): 704 """Wrapper that calls the converted version of f.""" 705 options = converter.ConversionOptions( 706 recursive=recursive, 707 user_requested=user_requested, 708 optional_features=optional_features) 709 try: 710 with conversion_ctx: 711 return converted_call(f, args, kwargs, options=options) 712 except Exception as e: # pylint:disable=broad-except 713 if hasattr(e, 'ag_error_metadata'): 714 raise e.ag_error_metadata.to_exception(e) 715 else: 716 raise 717 718 if inspect.isfunction(f) or inspect.ismethod(f): 719 wrapper = functools.update_wrapper(wrapper, f) 720 721 decorated_wrapper = tf_decorator.make_decorator(f, wrapper) 722 return autograph_artifact(decorated_wrapper) 723 724 return decorator 725 726 727# pylint:disable=line-too-long 728@tf_export('autograph.to_graph', v1=[]) 729def to_graph(entity, recursive=True, experimental_optional_features=None): 730 """Converts a Python entity into a TensorFlow graph. 731 732 Also see: `tf.autograph.to_code`, `tf.function`. 733 734 Unlike `tf.function`, `to_graph` is a low-level transpiler that converts 735 Python code to TensorFlow graph code. It does not implement any caching, 736 variable management or create any actual ops, and is best used where greater 737 control over the generated TensorFlow graph is desired. Another difference 738 from `tf.function` is that `to_graph` will not wrap the graph into a 739 TensorFlow function or a Python callable. Internally, `tf.function` uses 740 `to_graph`. 741 742 Example usage: 743 744 >>> def f(x): 745 ... if x > 0: 746 ... y = x * x 747 ... else: 748 ... y = -x 749 ... return y 750 ... 751 >>> converted_f = to_graph(f) 752 >>> x = tf.constant(2) 753 >>> converted_f(x) # converted_foo is like a TensorFlow Op. 754 <tf.Tensor: shape=(), dtype=int32, numpy=4> 755 756 Supported Python entities include: 757 * functions 758 * classes 759 * object methods 760 761 Functions are converted into new functions with converted code. 762 763 Classes are converted by generating a new class whose methods use converted 764 code. 765 766 Methods are converted into unbound function that have an additional first 767 argument called `self`. 768 769 For a tutorial, see the 770 [tf.function and AutoGraph guide](https://www.tensorflow.org/guide/function). 771 For more detailed information, see the 772 [AutoGraph reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md). 773 774 Args: 775 entity: Python callable or class to convert. 776 recursive: Whether to recursively convert any functions that the converted 777 function may call. 778 experimental_optional_features: `None`, a tuple of, or a single 779 `tf.autograph.experimental.Feature` value. 780 781 Returns: 782 Same as `entity`, the converted Python function or class. 783 784 Raises: 785 ValueError: If the entity could not be converted. 786 """ 787 try: 788 program_ctx = converter.ProgramContext( 789 options=converter.ConversionOptions( 790 recursive=recursive, 791 user_requested=True, 792 optional_features=experimental_optional_features)) 793 return autograph_artifact(_convert_actual(entity, program_ctx)) 794 except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e: 795 logging.error(1, 'Error converting %s', entity, exc_info=True) 796 raise ConversionError('converting {}: {}: {}'.format( 797 entity, e.__class__.__name__, str(e))) 798 799 800@tf_export(v1=['autograph.to_graph']) 801def to_graph_v1(entity, 802 recursive=True, 803 arg_values=None, 804 arg_types=None, 805 experimental_optional_features=None): 806 """Converts a Python entity into a TensorFlow graph. 807 808 Also see: `tf.autograph.to_code`, `tf.function`. 809 810 Unlike `tf.function`, `to_graph` is a low-level transpiler that converts 811 Python code to TensorFlow graph code. It does not implement any caching, 812 variable management or create any actual ops, and is best used where greater 813 control over the generated TensorFlow graph is desired. Another difference 814 from `tf.function` is that `to_graph` will not wrap the graph into a 815 TensorFlow function or a Python callable. Internally, `tf.function` uses 816 `to_graph`. 817 818 _Example Usage_ 819 820 ```python 821 def foo(x): 822 if x > 0: 823 y = x * x 824 else: 825 y = -x 826 return y 827 828 converted_foo = to_graph(foo) 829 830 x = tf.constant(1) 831 y = converted_foo(x) # converted_foo is a TensorFlow Op-like. 832 assert is_tensor(y) 833 ``` 834 835 Supported Python entities include: 836 * functions 837 * classes 838 * object methods 839 840 Functions are converted into new functions with converted code. 841 842 Classes are converted by generating a new class whose methods use converted 843 code. 844 845 Methods are converted into unbound function that have an additional first 846 argument called `self`. 847 848 Args: 849 entity: Python callable or class to convert. 850 recursive: Whether to recursively convert any functions that the converted 851 function may call. 852 arg_values: Deprecated. 853 arg_types: Deprecated. 854 experimental_optional_features: `None`, a tuple of, or a single 855 `tf.autograph.experimental.Feature` value. 856 857 Returns: 858 Same as `entity`, the converted Python function or class. 859 860 Raises: 861 ValueError: If the entity could not be converted. 862 """ 863 del arg_types 864 del arg_values 865 return to_graph( 866 entity, 867 recursive=recursive, 868 experimental_optional_features=experimental_optional_features) 869 870 871@tf_export(v1=['autograph.to_code']) 872def to_code_v1(entity, 873 recursive=True, 874 arg_values=None, 875 arg_types=None, 876 indentation=' ', 877 experimental_optional_features=None): 878 """Returns the source code generated by AutoGraph, as a string. 879 880 Example usage: 881 882 >>> def f(x): 883 ... if x < 0: 884 ... x = -x 885 ... return x 886 >>> tf.autograph.to_code(f) 887 "...def tf__f(x):..." 888 889 Also see: `tf.autograph.to_graph`. 890 891 Note: If a function has been decorated with `tf.function`, pass its 892 underlying Python function, rather than the callable that `tf.function 893 creates: 894 895 >>> @tf.function 896 ... def f(x): 897 ... if x < 0: 898 ... x = -x 899 ... return x 900 >>> tf.autograph.to_code(f.python_function) 901 "...def tf__f(x):..." 902 903 Args: 904 entity: Python callable or class. 905 recursive: Whether to recursively convert any functions that the converted 906 function may call. 907 arg_values: Deprecated. 908 arg_types: Deprecated. 909 indentation: Deprecated. 910 experimental_optional_features: `None`, a tuple of, or a single 911 `tf.autograph.experimental.Feature` value. 912 913 Returns: 914 The converted code as string. 915 """ 916 del arg_values 917 del arg_types 918 del indentation 919 return to_code( 920 entity, 921 recursive=recursive, 922 experimental_optional_features=experimental_optional_features) 923 924 925@tf_export('autograph.to_code', v1=[]) 926def to_code(entity, recursive=True, experimental_optional_features=None): 927 """Returns the source code generated by AutoGraph, as a string. 928 929 Example usage: 930 931 >>> def f(x): 932 ... if x < 0: 933 ... x = -x 934 ... return x 935 >>> tf.autograph.to_code(f) 936 "...def tf__f(x):..." 937 938 Also see: `tf.autograph.to_graph`. 939 940 Note: If a function has been decorated with `tf.function`, pass its 941 underlying Python function, rather than the callable that `tf.function 942 creates: 943 944 >>> @tf.function 945 ... def f(x): 946 ... if x < 0: 947 ... x = -x 948 ... return x 949 >>> tf.autograph.to_code(f.python_function) 950 "...def tf__f(x):..." 951 952 Args: 953 entity: Python callable or class to convert. 954 recursive: Whether to recursively convert any functions that the converted 955 function may call. 956 experimental_optional_features: `None`, a tuple of, or a single 957 `tf.autograph.experimental.Feature` value. 958 959 Returns: 960 The converted code as string. 961 """ 962 source = tf_inspect.getsource( 963 to_graph( 964 entity, 965 recursive=recursive, 966 experimental_optional_features=experimental_optional_features)) 967 return textwrap.dedent(source) 968 969 970_TRANSPILER = PyToTF() 971