# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module contains the user- and codegen-facing API for AutoGraph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import imp
import inspect
import os
import sys
import textwrap
import traceback

import six

from tensorflow.python.autograph import operators
from tensorflow.python.autograph import utils
from tensorflow.python.autograph.converters import asserts
from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.converters import call_trees
from tensorflow.python.autograph.converters import conditional_expressions
from tensorflow.python.autograph.converters import continue_statements
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.converters import directives
from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.converters import lists
from tensorflow.python.autograph.converters import logical_expressions
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.converters import slices
from tensorflow.python.autograph.converters import variables
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers
from tensorflow.python.autograph.core import unsupported_features_checker
from tensorflow.python.autograph.impl import conversion
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import error_utils
from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transpiler
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.utils import ag_logging as logging
from tensorflow.python.eager import function
from tensorflow.python.framework import errors_impl
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import tf_stack
from tensorflow.python.util.tf_export import tf_export


def is_autograph_strict_conversion_mode():
  return int(os.environ.get('AUTOGRAPH_STRICT_CONVERSION', '0')) > 0

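
# A small usage sketch (illustrative only, not part of the API): strict
# conversion mode is toggled purely through the environment variable read
# above. When enabled, conversion failures raise instead of silently falling
# back to running the unconverted function (see _fall_back_unconverted below).
def _example_enable_strict_conversion():
  os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
  return is_autograph_strict_conversion_mode()  # Now returns True.
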

#
# Error handling
#


# TODO(mdan): Export this symbol.
class AutoGraphError(errors.PyCTError):
  """Base class for all AutoGraph exceptions."""
  pass


class ConversionError(AutoGraphError):
  """Raised during the conversion process."""
  pass


class StagingError(AutoGraphError):
  """Raised during the staging (i.e. Python execution) of converted code."""
  pass


class _ErrorMetadata(error_utils.ErrorMetadataBase):
  """AutoGraph-specific error metadata. See base class."""

  def create_exception(self, source_error):
    preferred_type = type(source_error)
    if issubclass(preferred_type, errors_impl.OpError):
      # Best-effort unpacking of OpError exceptions.
      # TODO(mdan): Use a mechanism that is more future-proof.
      init_argspec = tf_inspect.getfullargspec(preferred_type.__init__)
      message = self.get_message()
      init_args = tuple(init_argspec.args)
      # At the time of this writing, TF errors take either 3 or 4 arguments;
      # the '*args' argument may or may not be used.
      if init_args == ('self', 'node_def', 'op', 'message'):
        return preferred_type(source_error.node_def, source_error.op, message,
                              source_error.experimental_payloads)

    elif preferred_type in (errors.PyCTError, AutoGraphError, ConversionError,
                            StagingError, errors_impl.InaccessibleTensorError,
                            errors_impl.OperatorNotAllowedInGraphError):
      return preferred_type(self.get_message())

    exc = super(_ErrorMetadata, self).create_exception(source_error)
    if exc is not None:
      return exc

    # Note: While mutating an error's message property will often change the
    # message it displays, there is no standard way in Python to do that. The
    # safest way is therefore to create a new exception. For user-defined
    # exceptions, we could define an interface that allows them to work under
    # this mechanism.
    return StagingError(self.get_message())


def _attach_error_metadata(e, f):
  """Augments an error with the metadata necessary for rewrite."""
  if hasattr(e, 'ag_pass_through'):
    return

  metadata = getattr(e, 'ag_error_metadata', None)
  source_map = f.ag_source_map

  if metadata is None:
    logging.log(1, 'Caught error in user callable %s', f, exc_info=True)
    message = '{}: {}'.format(e.__class__.__name__, e)
  else:
    message = None

  cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:]

  e.ag_error_metadata = _ErrorMetadata(cause_tb, metadata, message, source_map,
                                       __file__)


class StackTraceMapper(tf_stack.StackTraceMapper):
  """Remaps generated code to code it originated from."""

  def __init__(self, converted_fn):
    super().__init__()
    self._source_map = converted_fn.ag_source_map
    # This may be called repeatedly: once on entry, by the superclass, then by
    # each child context manager.
    self._cached_map = None

  def get_effective_source_map(self):
    if self._cached_map is not None:
      return self._cached_map

    parent_map = self.parent.get_effective_source_map()

    effective_source_map = {}
    for loc, origin in self._source_map.items():
      effective_source_map[(loc.filename, loc.lineno)] = (origin.loc.filename,
                                                          origin.loc.lineno,
                                                          origin.function_name)

    for key, value in parent_map.items():
      filename, lineno, _ = value
      value_loc = origin_info.LineLocation(filename=filename, lineno=lineno)
      if value_loc in self._source_map:
        origin = self._source_map[value_loc]
        effective_source_map[key] = (origin.loc.filename, origin.loc.lineno,
                                     origin.function_name)
      else:
        effective_source_map[key] = value

    self._cached_map = effective_source_map
    return effective_source_map

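
# Illustrative sketch (hypothetical helper, mirroring the usage inside
# converted_call below): the mapper is used as a context manager around calls
# into generated code, so that frames pointing at the generated source get
# reported against the user's original file, line and function instead.
def _example_stack_trace_mapping(converted_fn, *args):
  # `converted_fn` is assumed to be the output of _convert_actual, so it
  # carries the ag_source_map attribute the mapper reads.
  with StackTraceMapper(converted_fn), tf_stack.CurrentModuleFilter():
    return converted_fn(*args)
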

#
# Actual source code transformation
#


class PyToTF(transpiler.PyToPy):
  """The TensorFlow AutoGraph transformer."""

  def __init__(self):
    super(PyToTF, self).__init__()
    self._extra_locals = None

  def get_transformed_name(self, node):
    return 'tf__' + super(PyToTF, self).get_transformed_name(node)

  def get_extra_locals(self):
    if self._extra_locals is None:
      # TODO(mdan): Move into core or replace with an actual importable module.
      # Craft a module that exposes the external API as well as certain
      # internal modules.
      ag_internal = imp.new_module('autograph')
      ag_internal.__dict__.update(inspect.getmodule(PyToTF).__dict__)
      ag_internal.ConversionOptions = converter.ConversionOptions
      ag_internal.STD = converter.STANDARD_OPTIONS
      ag_internal.Feature = converter.Feature
      ag_internal.utils = utils
      ag_internal.FunctionScope = function_wrappers.FunctionScope
      ag_internal.with_function_scope = function_wrappers.with_function_scope
      # TODO(mdan): Add safeguards against name clashes.
      # We don't want to create a submodule because we want the operators to be
      # accessible as ag__.<operator>
      ag_internal.__dict__.update(special_functions.__dict__)
      ag_internal.__dict__.update(operators.__dict__)

      self._extra_locals = {'ag__': ag_internal}
    return self._extra_locals

  def get_caching_key(self, ctx):
    return ctx.options

  def initial_analysis(self, node, ctx):
    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx, None)
    node = reaching_definitions.resolve(node, ctx, graphs)
    anno.dup(
        node,
        {
            anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
        },
    )
    return node

  def transform_ast(self, node, ctx):
    unsupported_features_checker.verify(node)
    node = self.initial_analysis(node, ctx)

    node = functions.transform(node, ctx)
    node = directives.transform(node, ctx)
    node = break_statements.transform(node, ctx)
    if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
      node = asserts.transform(node, ctx)
    # Note: sequencing continue canonicalization before the for-loop one avoids
    # dealing with the extra loop increment operation that the for-loop
    # canonicalization creates.
    node = continue_statements.transform(node, ctx)
    node = return_statements.transform(node, ctx)
    if ctx.user.options.uses(converter.Feature.LISTS):
      node = lists.transform(node, ctx)
      node = slices.transform(node, ctx)
    node = call_trees.transform(node, ctx)
    node = control_flow.transform(node, ctx)
    node = conditional_expressions.transform(node, ctx)
    node = logical_expressions.transform(node, ctx)
    node = variables.transform(node, ctx)
    return node


def _convert_actual(entity, program_ctx):
  """Applies AutoGraph to entity."""

  # TODO(mdan): Put these extra fields inside __autograph_info__.
  if not hasattr(entity, '__code__'):
    raise ValueError('Cannot apply autograph to a function that doesn\'t '
                     'expose a __code__ object. If this is a @tf.function,'
                     ' try passing f.python_function instead.')

  transformed, module, source_map = _TRANSPILER.transform(entity, program_ctx)

  assert not hasattr(transformed, 'ag_module')
  assert not hasattr(transformed, 'ag_source_map')
  transformed.ag_module = module
  transformed.ag_source_map = source_map
  return transformed

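
# Illustrative sketch (not a public entry point): applying the transformer
# directly to a plain Python function, mirroring what _convert_actual above
# does through the shared _TRANSPILER instance. STANDARD_OPTIONS stands in
# for whatever ConversionOptions the caller would normally supply.
def _example_pytotf_transform():
  def f(x):
    return x + 1

  program_ctx = converter.ProgramContext(options=converter.STANDARD_OPTIONS)
  transformed_f, module, source_map = PyToTF().transform(f, program_ctx)
  del module, source_map  # _convert_actual attaches these to the result.
  return transformed_f
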

#
# Generated code support
#


def autograph_artifact(entity, extras=None):
  if inspect.ismethod(entity):
    setattr(entity.__func__, 'autograph_info__', extras)
  else:
    setattr(entity, 'autograph_info__', extras)
  return entity


def is_autograph_artifact(entity):
  return hasattr(entity, 'autograph_info__')

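
# Small sketch: marking a callable as an AutoGraph artifact tells
# converted_call (below) to leave it alone. The helper name is hypothetical.
def _example_artifact_marking():
  def helper():
    return 1

  autograph_artifact(helper)
  assert is_autograph_artifact(helper)
  return helper
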

def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
  """Converts a function call inline.

  For internal use only.

  Note: The argument list is optimized for readability of generated code, which
  may look like this:

    ag__.converted_call(f, (arg1, arg2), None, fscope)
    ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
    ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope)

  Args:
    f: The function to convert.
    args: Tuple, the original positional arguments of f.
    kwargs: Optional[Dict], the original keyword arguments of f.
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made.
    options: Optional[converter.ConversionOptions], conversion options. If not
      specified, the value of caller_fn_scope.callopts is used. Either options
      or caller_fn_scope must be present.

  Returns:
    Any, the result of executing a possibly-converted `f` with the given
      arguments.
  """
  logging.log(1, 'Converted call: %s\n    args: %s\n    kwargs: %s\n', f, args,
              kwargs)

  if options is None:
    if caller_fn_scope is None:
      raise ValueError('either caller_fn_scope or options must have a value')
    options = caller_fn_scope.callopts

  if conversion.is_in_allowlist_cache(f, options):
    logging.log(2, 'Allowlisted %s: from cache', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
    logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if is_autograph_artifact(f):
    logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f)
    return _call_unconverted(f, args, kwargs, options)

  # If this is a partial, unwrap it and redo all the checks.
  if isinstance(f, functools.partial):
    new_kwargs = {}
    if f.keywords is not None:
      # Use copy to avoid mutating the underlying keywords.
      new_kwargs = f.keywords.copy()
    if kwargs is not None:
      new_kwargs.update(kwargs)
    new_args = f.args + args
    logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f, new_args,
                new_kwargs)
    return converted_call(
        f.func,
        new_args,
        new_kwargs,
        caller_fn_scope=caller_fn_scope,
        options=options)

  if inspect_utils.isbuiltin(f):
    if f is eval:
      return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
    if f is super:
      return py_builtins.super_in_original_context(f, args, caller_fn_scope)
    if f is globals:
      return py_builtins.globals_in_original_context(caller_fn_scope)
    if f is locals:
      return py_builtins.locals_in_original_context(caller_fn_scope)
    if kwargs:
      return py_builtins.overload_of(f)(*args, **kwargs)
    else:
      return py_builtins.overload_of(f)(*args)

  if conversion.is_unsupported(f):
    return _call_unconverted(f, args, kwargs, options)

  if not options.user_requested and conversion.is_allowlisted(f):
    return _call_unconverted(f, args, kwargs, options)

  # internal_convert_user_code is turned off, for example, when issuing a
  # dynamic call conversion from generated code while in nonrecursive mode. In
  # that case we don't want to recurse, but we still have to convert things
  # like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs, options)

  try:
    if inspect.ismethod(f) or inspect.isfunction(f):
      target_entity = f
      effective_args = args

      f_self = getattr(f, '__self__', None)
      if f_self is not None:
        if isinstance(f_self, function.TfMethodTarget):
          f_self = f_self.target
        effective_args = (f_self,) + effective_args

    elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'):
      # Callable objects. Dunder methods have special lookup rules, see:
      # https://docs.python.org/3/reference/datamodel.html#specialnames
      # TODO(mdan): Recurse into converted_call to simplify other verifications.
      # This should be handled in the same way as partials.
      target_entity = f.__class__.__call__
      effective_args = (f,) + args

    else:
      target_entity = f
      raise NotImplementedError('unknown callable type "%s"' % type(f))

  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise
    return _fall_back_unconverted(f, args, kwargs, options, e)

  if not hasattr(target_entity, '__code__'):
    logging.log(2, 'Permanently allowed: %s: native binding', target_entity)
    return _call_unconverted(f, args, kwargs, options)
  elif (hasattr(target_entity.__code__, 'co_filename') and
        target_entity.__code__.co_filename == '<string>'):
    # TODO(mdan): __globals__['txt'] might work in Py3.
    logging.log(2, 'Permanently allowed: %s: dynamic code (exec?)',
                target_entity)
    return _call_unconverted(f, args, kwargs, options)

  try:
    program_ctx = converter.ProgramContext(options=options)
    converted_f = _convert_actual(target_entity, program_ctx)
    if logging.has_verbosity(2):
      _log_callargs(converted_f, effective_args, kwargs)
  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise
    return _fall_back_unconverted(f, args, kwargs, options, e)

  with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
    try:
      if kwargs is not None:
        result = converted_f(*effective_args, **kwargs)
      else:
        result = converted_f(*effective_args)
    except Exception as e:
      _attach_error_metadata(e, converted_f)
      raise

  return result

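
# Minimal sketch of the calling convention documented in the docstring above;
# `my_fn` is a hypothetical user function and STANDARD_OPTIONS stands in for
# the options that generated code would normally thread through `fscope`.
def _example_converted_call():
  def my_fn(a, b):
    return a + b

  return converted_call(
      my_fn, (1, 2), None, options=converter.STANDARD_OPTIONS)
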

def _call_unconverted(f, args, kwargs, options, update_cache=True):
  """Calls the original function without converting with AutoGraph."""
  if update_cache:
    conversion.cache_allowlisted(f, options)

  if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget):
    return f.__self__.call(args, kwargs)

  if kwargs is not None:
    return f(*args, **kwargs)
  return f(*args)


def _fall_back_unconverted(f, args, kwargs, options, exc):
  """Falls back to calling the function unconverted, in case of error."""
  # TODO(mdan): Consider adding an internal metric.
  warning_template = (
      'AutoGraph could not transform %s and will run it as-is.\n'
      '%s'
      'Cause: %s\n'
      'To silence this warning, decorate the function with'
      ' @tf.autograph.experimental.do_not_convert')
  if isinstance(exc, errors.UnsupportedLanguageElementError):
    if not conversion.is_in_allowlist_cache(f, options):
      logging.warn(warning_template, f, '', exc)
  else:
    file_bug_message = (
        'Please report this to the TensorFlow team. When filing the bug, set'
        ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
        ' attach the full output.\n')
    logging.warn(warning_template, f, file_bug_message, exc)

  return _call_unconverted(f, args, kwargs, options)

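
# Small sketch related to the bug-report message above: verbosity can also be
# raised programmatically through the ag_logging module (aliased as `logging`
# here), which is what the AUTOGRAPH_VERBOSITY environment variable feeds.
def _example_enable_debug_logging():
  logging.set_verbosity(10, alsologtostdout=True)
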

#
# TensorFlow integration
#


@tf_export('__internal__.autograph.tf_convert', v1=[])
def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
  """Decorator that applies AutoGraph to a function.

  Use in internal APIs.

  This API is suitable for higher-order functions internal to the TensorFlow
  API, and more generally any function to which AutoGraph is not applied.

  Guidance: `convert` was a decorator meant for use directly by developers, but
  most of today's uses go through `tf.function`. `tf_convert` is to be called
  from higher-order functions internal to TF. By default, all the internal
  TensorFlow functions are skipped when AutoGraph processes the code. This may
  lead to user-supplied functions being incorrectly skipped as well.
  `tf_convert` helps avoid that. See the following example for more details.

  ```
  =====tf_internal_module.py=====

  def unconverted(input_fn):
    return input_fn()

  def converted(input_fn):
    return tf.__internal__.autograph.tf_convert(
       input_fn, ctx=tf.__internal__.autograph.control_status_ctx())()

  ======user_module.py======

  @tf.function
  def foo(input_fn):
    return unconverted(input_fn)

  @tf.function
  def bar(input_fn):
    return converted(input_fn)

  @tf.function(autograph=False)
  def baz(input_fn):
    return converted(input_fn)
  ```

  The `foo` method above will execute the `input_fn` without autograph
  conversion, while the `bar` method will run an autographed `input_fn`. The
  `baz` method will run an unconverted `input_fn`, since `tf_convert` respects
  the control status context.

  Note that both methods in `tf_internal_module` are skipped by autograph when
  tracing the `tf.function`. The configuration of whether a module/package
  should be skipped by autograph is controlled in
  tensorflow/python/autograph/core/config.py.

  Args:
    f: Callable.
    ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
    convert_by_default: bool, whether to use AutoGraph when the context doesn't
      specify.
    user_requested: bool, whether to ignore the conversion allowlist. See
      ConversionOptions.user_requested.

  Returns:
    Either `f` or the converted version of `f`.
  """

  if is_autograph_artifact(f):
    return f
  f_wrapper = f
  decorators, f = tf_decorator.unwrap(f)

  # TODO(mdan): Grab features from context.
  # Note: we pass the original context through to `convert` to properly handle
  # the following scenario, which can occur inside TF implementations:
  #
  #   ctx = ag_ctx.control_status_ctx()
  #   @function(autograph=False)  # Low-level graph code
  #   def inner_fn():
  #     # The context is disabled here, but should be enabled in user_fn.
  #     tf_convert(user_fn, ctx=ctx)
  if ctx.status == ag_ctx.Status.ENABLED:
    wrapper_factory = convert(
        recursive=True, user_requested=user_requested, conversion_ctx=ctx)
  elif ctx.status == ag_ctx.Status.DISABLED:
    wrapper_factory = do_not_convert
  elif ctx.status == ag_ctx.Status.UNSPECIFIED:
    if convert_by_default:
      wrapper_factory = convert(
          recursive=True, user_requested=user_requested, conversion_ctx=ctx)
    else:
      wrapper_factory = call_with_unspecified_conversion_status
  else:
    assert False, 'This switch contains all possible cases!'
  wrapper = wrapper_factory(f)

  if decorators:
    wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)

  return autograph_artifact(wrapper)

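
# Illustrative sketch of the intended internal usage pattern described in the
# docstring above: a TF-internal higher-order function wraps the user callback
# with tf_convert so AutoGraph honors the caller's control status. `user_fn`
# is a hypothetical user-supplied callable.
def _example_tf_convert_usage(user_fn):
  converted = tf_convert(user_fn, ctx=ag_ctx.control_status_ctx())
  return converted()
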

def call_with_unspecified_conversion_status(func):
  """Decorator that resets the conversion context to the unspecified status."""

  def wrapper(*args, **kwargs):
    with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED):
      return func(*args, **kwargs)

  if inspect.isfunction(func) or inspect.ismethod(func):
    wrapper = functools.update_wrapper(wrapper, func)

  return autograph_artifact(wrapper)


def _log_callargs(f, args, kwargs):
  """Logging helper."""
  logging.log(2, 'Defaults of %s : %s', f, f.__defaults__)
  if not six.PY2:
    logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__)

  if kwargs is not None:
    callargs = tf_inspect.getcallargs(f, *args, **kwargs)
  else:
    callargs = tf_inspect.getcallargs(f, *args)

  formatted_callargs = '\n'.join(
      '    {}: {}'.format(k, v) for k, v in callargs.items())
  logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs)


#
# Public API
#


@tf_export('autograph.experimental.do_not_convert')
def do_not_convert(func=None):
  """Decorator that suppresses the conversion of a function.

  Args:
    func: function to decorate.

  Returns:
    If `func` is not None, returns a `Callable` which is equivalent to
    `func`, but is not converted by AutoGraph.
    If `func` is None, returns a decorator that, when invoked with a
    single `func` argument, returns a `Callable` equivalent to the
    above case.
  """
  if func is None:
    return do_not_convert

  def wrapper(*args, **kwargs):
    with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
      return func(*args, **kwargs)

  if inspect.isfunction(func) or inspect.ismethod(func):
    wrapper = functools.update_wrapper(wrapper, func)

  return autograph_artifact(wrapper)

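
# Small usage sketch: a callable wrapped with do_not_convert is marked as an
# AutoGraph artifact and always runs with conversion disabled, even when
# invoked from converted code. The wrapped function here is hypothetical.
def _example_do_not_convert_usage():
  @do_not_convert
  def plain_python(x):
    return x + 1

  assert is_autograph_artifact(plain_python)
  return plain_python(1)  # Runs as ordinary Python.
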

# TODO(mdan): Make private.
def convert(recursive=False,
            optional_features=None,
            user_requested=True,
            conversion_ctx=ag_ctx.NullCtx()):
  """Decorator that compiles a function to use TensorFlow ops.

  The decorator is dynamic: it recompiles the target whenever the decorated
  function is called. This means the parameter values are known at conversion
  time. It also means that repeated calls with different types of parameters
  will be correctly processed.

  Args:
    recursive: bool, whether to recursively convert any functions or classes
      that the converted function may use.
    optional_features: converter.Feature, allows toggling optional or
      experimental features. When set to None, only the core features are
      enabled.
    user_requested: bool, whether this is a function that the user explicitly
      asked to be converted. See ConversionOptions.user_requested.
    conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context in
      which `f` is used.

  Returns:
    Callable, a decorator that converts the given function into an equivalent
    function that uses TensorFlow ops.
  """

  def decorator(f):
    """Decorator implementation."""

    def wrapper(*args, **kwargs):
      """Wrapper that calls the converted version of f."""
      options = converter.ConversionOptions(
          recursive=recursive,
          user_requested=user_requested,
          optional_features=optional_features)
      try:
        with conversion_ctx:
          return converted_call(f, args, kwargs, options=options)
      except Exception as e:  # pylint:disable=broad-except
        if hasattr(e, 'ag_error_metadata'):
          raise e.ag_error_metadata.to_exception(e)
        else:
          raise

    if inspect.isfunction(f) or inspect.ismethod(f):
      wrapper = functools.update_wrapper(wrapper, f)

    decorated_wrapper = tf_decorator.make_decorator(f, wrapper)
    return autograph_artifact(decorated_wrapper)

  return decorator

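
# Illustrative sketch (assuming plain-Python inputs): convert() recompiles the
# target on each call, and the converted function still accepts ordinary
# Python values, in which case the staged control flow behaves like regular
# Python.
def _example_convert_usage():
  @convert(recursive=True)
  def abs_or_square(x):
    if x > 0:
      return x * x
    return -x

  return abs_or_square(2)  # Expected to return 4 for a plain Python int.
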

# pylint:disable=line-too-long
@tf_export('autograph.to_graph', v1=[])
def to_graph(entity, recursive=True, experimental_optional_features=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  Example usage:

  >>> def f(x):
  ...   if x > 0:
  ...     y = x * x
  ...   else:
  ...     y = -x
  ...   return y
  ...
  >>> converted_f = to_graph(f)
  >>> x = tf.constant(2)
  >>> converted_f(x)  # converted_f is like a TensorFlow Op.
  <tf.Tensor: shape=(), dtype=int32, numpy=4>

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound functions that have an additional first
  argument called `self`.

  For a tutorial, see the
  [tf.function and AutoGraph guide](https://www.tensorflow.org/guide/function).
  For more detailed information, see the
  [AutoGraph reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md).

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  try:
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            user_requested=True,
            optional_features=experimental_optional_features))
    return autograph_artifact(_convert_actual(entity, program_ctx))
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    logging.error(1, 'Error converting %s', entity, exc_info=True)
    raise ConversionError('converting {}: {}: {}'.format(
        entity, e.__class__.__name__, str(e)))

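
# Additional sketch drawn from the notes above (hypothetical class): methods
# are converted into unbound functions, so the instance is passed explicitly
# as the first argument when calling the converted result.
def _example_to_graph_method():
  class Accumulator(object):

    def __init__(self):
      self.total = 0

    def add(self, x):
      if x > 0:
        self.total += x
      return self.total

  converted_add = to_graph(Accumulator().add)
  return converted_add(Accumulator(), 3)  # Pass the instance as `self`.
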

@tf_export(v1=['autograph.to_graph'])
def to_graph_v1(entity,
                recursive=True,
                arg_values=None,
                arg_types=None,
                experimental_optional_features=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is like a TensorFlow Op.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound functions that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    arg_values: Deprecated.
    arg_types: Deprecated.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  del arg_types
  del arg_values
  return to_graph(
      entity,
      recursive=recursive,
      experimental_optional_features=experimental_optional_features)


@tf_export(v1=['autograph.to_code'])
def to_code_v1(entity,
               recursive=True,
               arg_values=None,
               arg_types=None,
               indentation='  ',
               experimental_optional_features=None):
  """Returns the source code generated by AutoGraph, as a string.

  Example usage:

  >>> def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f)
  "...def tf__f(x):..."

  Also see: `tf.autograph.to_graph`.

  Note: If a function has been decorated with `tf.function`, pass its
  underlying Python function, rather than the callable that `tf.function`
  creates:

  >>> @tf.function
  ... def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f.python_function)
  "...def tf__f(x):..."

  Args:
    entity: Python callable or class.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    arg_values: Deprecated.
    arg_types: Deprecated.
    indentation: Deprecated.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    The converted code as string.
  """
  del arg_values
  del arg_types
  del indentation
  return to_code(
      entity,
      recursive=recursive,
      experimental_optional_features=experimental_optional_features)


@tf_export('autograph.to_code', v1=[])
def to_code(entity, recursive=True, experimental_optional_features=None):
  """Returns the source code generated by AutoGraph, as a string.

  Example usage:

  >>> def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f)
  "...def tf__f(x):..."

  Also see: `tf.autograph.to_graph`.

  Note: If a function has been decorated with `tf.function`, pass its
  underlying Python function, rather than the callable that `tf.function`
  creates:

  >>> @tf.function
  ... def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f.python_function)
  "...def tf__f(x):..."

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    The converted code as string.
  """
  source = tf_inspect.getsource(
      to_graph(
          entity,
          recursive=recursive,
          experimental_optional_features=experimental_optional_features))
  return textwrap.dedent(source)


_TRANSPILER = PyToTF()