• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""This module contains the user- and codegen-facing API for AutoGraph."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import imp
23import inspect
24import os
25import sys
26import textwrap
27import traceback
28
29import six
30
31from tensorflow.python.autograph import operators
32from tensorflow.python.autograph import utils
33from tensorflow.python.autograph.converters import asserts
34from tensorflow.python.autograph.converters import break_statements
35from tensorflow.python.autograph.converters import call_trees
36from tensorflow.python.autograph.converters import conditional_expressions
37from tensorflow.python.autograph.converters import continue_statements
38from tensorflow.python.autograph.converters import control_flow
39from tensorflow.python.autograph.converters import directives
40from tensorflow.python.autograph.converters import functions
41from tensorflow.python.autograph.converters import lists
42from tensorflow.python.autograph.converters import logical_expressions
43from tensorflow.python.autograph.converters import return_statements
44from tensorflow.python.autograph.converters import slices
45from tensorflow.python.autograph.converters import variables
46from tensorflow.python.autograph.core import ag_ctx
47from tensorflow.python.autograph.core import converter
48from tensorflow.python.autograph.core import function_wrappers
49from tensorflow.python.autograph.core import unsupported_features_checker
50from tensorflow.python.autograph.impl import conversion
51from tensorflow.python.autograph.lang import special_functions
52from tensorflow.python.autograph.operators import py_builtins
53from tensorflow.python.autograph.pyct import anno
54from tensorflow.python.autograph.pyct import cfg
55from tensorflow.python.autograph.pyct import error_utils
56from tensorflow.python.autograph.pyct import errors
57from tensorflow.python.autograph.pyct import inspect_utils
58from tensorflow.python.autograph.pyct import origin_info
59from tensorflow.python.autograph.pyct import qual_names
60from tensorflow.python.autograph.pyct import transpiler
61from tensorflow.python.autograph.pyct.static_analysis import activity
62from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
63from tensorflow.python.autograph.utils import ag_logging as logging
64from tensorflow.python.eager import function
65from tensorflow.python.framework import errors_impl
66from tensorflow.python.util import tf_decorator
67from tensorflow.python.util import tf_inspect
68from tensorflow.python.util import tf_stack
69from tensorflow.python.util.tf_export import tf_export
70
71
def is_autograph_strict_conversion_mode():
  """Returns True when the AUTOGRAPH_STRICT_CONVERSION env var is positive."""
  flag = os.environ.get('AUTOGRAPH_STRICT_CONVERSION', '0')
  return int(flag) > 0
74
75
76#
77# Error handling
78#
79
80
81# TODO(mdan): Export this symbol.
class AutoGraphError(errors.PyCTError):
  """Base class for all AutoGraph exceptions."""
85
86
class ConversionError(AutoGraphError):
  """Raised during the conversion process."""
90
91
class StagingError(AutoGraphError):
  """Raised during the staging (i.e. Python execution) of converted code."""
95
96
class _ErrorMetadata(error_utils.ErrorMetadataBase):
  """AutoGraph-specific error metadata. See base class."""

  def create_exception(self, source_error):
    """Creates a rewritten exception of (preferably) the same type.

    Args:
      source_error: Exception, the error caught while executing converted
        code; its type is used as the preferred type for the new exception.

    Returns:
      Exception, a new exception carrying this metadata's message. Falls back
      to StagingError when the original type cannot be reconstructed safely.
    """
    preferred_type = type(source_error)
    if issubclass(preferred_type, errors_impl.OpError):
      # Best-effort unpacking of OpError exceptions.
      # TODO(mdan): Use a mechanism that is more future-proof.
      init_argspec = tf_inspect.getfullargspec(preferred_type.__init__)
      message = self.get_message()
      init_args = tuple(init_argspec.args)
      # At the time of this writing, TF errors either take 3 or 4 arguments,
      # with the fourth being error_code.
      if init_args == ('self', 'node_def', 'op', 'message', 'error_code'):
        return preferred_type(
            node_def=source_error.node_def,
            op=source_error.op,
            message=message,
            error_code=self.error_code)
      elif init_args == ('self', 'node_def', 'op', 'message'):
        if 'error_code' in init_argspec.kwonlyargs:
          # Bug fix: the keyword was previously misspelled 'errro_code', which
          # would raise TypeError here instead of building the exception.
          return preferred_type(
              node_def=source_error.node_def,
              op=source_error.op,
              message=message,
              error_code=self.error_code)
        else:
          return preferred_type(
              node_def=source_error.node_def,
              op=source_error.op,
              message=message)

    elif preferred_type in (errors.PyCTError, AutoGraphError, ConversionError,
                            StagingError, errors_impl.InaccessibleTensorError,
                            errors_impl.OperatorNotAllowedInGraphError):
      # These types accept a single message argument, so they can be rebuilt
      # directly.
      return preferred_type(self.get_message())

    exc = super(_ErrorMetadata, self).create_exception(source_error)
    if exc is not None:
      return exc

    # Note: While changing an error's message property to change the message it
    # displays will probably work a lot of times, there is no standard way in
    # Python to do that. The safest way is therefore to create a new exception.
    # For user defined exceptions, we could define an interface that allowed
    # them to work under this mechanism.
    return StagingError(self.get_message())
144
145
def _attach_error_metadata(e, f):
  """Augments an error with the metadata necessary for rewrite.

  Args:
    e: Exception, the error raised while executing converted code.
    f: Callable, the converted function; must expose an `ag_source_map`
      attribute mapping generated code locations to original ones.
  """
  # Errors flagged with ag_pass_through are propagated without rewriting.
  if hasattr(e, 'ag_pass_through'):
    return

  metadata = getattr(e, 'ag_error_metadata', None)
  source_map = f.ag_source_map

  if metadata is None:
    # First time this error is seen by AutoGraph: capture a message now,
    # before any re-raising can alter the context.
    logging.log(1, 'Caught error in user callable %s', f, exc_info=True)
    message = '{}: {}'.format(e.__class__.__name__, e)
  else:
    # Metadata from a nested converted call already holds the message.
    message = None

  # Drop the first traceback entry, which is the frame of the converted-call
  # machinery itself rather than user code. NOTE(review): this assumes the
  # function is called from inside an active `except` block, where
  # sys.exc_info() is populated — confirm at call sites.
  cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:]

  e.ag_error_metadata = _ErrorMetadata(
      cause_tb, metadata, message, source_map, __file__)
164
165
class StackTraceMapper(tf_stack.StackTraceMapper):
  """Remaps generated code to code it originated from."""

  def __init__(self, converted_fn):
    super().__init__()
    self._source_map = converted_fn.ag_source_map
    # Lazily-computed result; get_effective_source_map may be invoked
    # repeatedly (once on entry, by the superclass, then by each child
    # context manager).
    self._cached_map = None

  def get_effective_source_map(self):
    """Returns a dict mapping generated locations to original ones."""
    if self._cached_map is not None:
      return self._cached_map

    # Start with this function's own mappings.
    result = {}
    for loc, origin in self._source_map.items():
      result[(loc.filename, loc.lineno)] = (origin.loc.filename,
                                            origin.loc.lineno,
                                            origin.function_name)

    # Fold in the parent's mappings, translating through this map whenever
    # the parent's target location is itself generated code we know about.
    parent_map = self.parent.get_effective_source_map()
    for key, value in parent_map.items():
      filename, lineno, _ = value
      parent_loc = origin_info.LineLocation(filename=filename, lineno=lineno)
      if parent_loc in self._source_map:
        translated = self._source_map[parent_loc]
        result[key] = (translated.loc.filename, translated.loc.lineno,
                       translated.function_name)
      else:
        result[key] = value

    self._cached_map = result
    return result
200
201
202#
203# Actual source code transformation
204#
205
206
class PyToTF(transpiler.PyToPy):
  """The TensorFlow AutoGraph transformer."""

  def __init__(self):
    super(PyToTF, self).__init__()
    # Lazily-built module of names injected into generated code; see
    # get_extra_locals.
    self._extra_locals = None

  def get_transformed_name(self, node):
    """Prefixes transformed function names with 'tf__'."""
    return 'tf__' + super(PyToTF, self).get_transformed_name(node)

  def get_extra_locals(self):
    """Returns the extra symbols available to generated code, as 'ag__'."""
    if self._extra_locals is None:
      # TODO(mdan): Move into core or replace with an actual importable module.
      # Craft a module that exposes the external API as well as certain
      # internal modules.
      ag_internal = imp.new_module('autograph')
      ag_internal.__dict__.update(inspect.getmodule(PyToTF).__dict__)
      ag_internal.ConversionOptions = converter.ConversionOptions
      ag_internal.STD = converter.STANDARD_OPTIONS
      ag_internal.Feature = converter.Feature
      ag_internal.utils = utils
      ag_internal.FunctionScope = function_wrappers.FunctionScope
      ag_internal.with_function_scope = function_wrappers.with_function_scope
      # TODO(mdan): Add safeguards against name clashes.
      # We don't want to create a submodule because we want the operators to be
      # accessible as ag__.<operator>
      ag_internal.__dict__.update(special_functions.__dict__)
      ag_internal.__dict__.update(operators.__dict__)

      self._extra_locals = {'ag__': ag_internal}
    return self._extra_locals

  def get_caching_key(self, ctx):
    # Transformed code is cached per conversion-options value.
    return ctx.options

  def initial_analysis(self, node, ctx):
    """Runs the static analyses required before transformation."""
    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx, None)
    node = reaching_definitions.resolve(node, ctx, graphs)
    # Preserve the original definitions under a separate annotation so later
    # passes can still see the pre-transformation state.
    anno.dup(
        node,
        {
            anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
        },
    )
    return node

  def transform_ast(self, node, ctx):
    """Applies the conversion passes. The pass ordering is significant."""
    unsupported_features_checker.verify(node)
    node = self.initial_analysis(node, ctx)

    node = functions.transform(node, ctx)
    node = directives.transform(node, ctx)
    node = break_statements.transform(node, ctx)
    if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
      node = asserts.transform(node, ctx)
    # Note: sequencing continue canonicalization before for loop one avoids
    # dealing with the extra loop increment operation that the for
    # canonicalization creates.
    node = continue_statements.transform(node, ctx)
    node = return_statements.transform(node, ctx)
    if ctx.user.options.uses(converter.Feature.LISTS):
      node = lists.transform(node, ctx)
      node = slices.transform(node, ctx)
    node = call_trees.transform(node, ctx)
    node = control_flow.transform(node, ctx)
    node = conditional_expressions.transform(node, ctx)
    node = logical_expressions.transform(node, ctx)
    node = variables.transform(node, ctx)
    return node
278
279
def _convert_actual(entity, program_ctx):
  """Applies AutoGraph to entity."""

  # TODO(mdan): Put these extra fields inside __autograph_info__.
  if not hasattr(entity, '__code__'):
    raise ValueError('Cannot apply autograph to a function that doesn\'t '
                     'expose a __code__ object. If this is a @tf.function,'
                     ' try passing f.python_function instead.')

  transformed, module, source_map = _TRANSPILER.transform(entity, program_ctx)

  # The transpiler must hand back a fresh object; these attributes are
  # attached exactly once.
  for attr_name in ('ag_module', 'ag_source_map'):
    assert not hasattr(transformed, attr_name)
  transformed.ag_module = module
  transformed.ag_source_map = source_map
  return transformed
296
297
298#
299# Generated code support
300#
301
302
def autograph_artifact(entity, extras=None):
  """Marks entity as an AutoGraph-generated artifact and returns it."""
  # For bound methods, the marker must go on the underlying function, since
  # attributes cannot be set on the method object itself.
  target = entity.__func__ if inspect.ismethod(entity) else entity
  setattr(target, 'autograph_info__', extras)
  return entity
309
310
def is_autograph_artifact(entity):
  """Returns True if entity carries the AutoGraph artifact marker."""
  marker_present = hasattr(entity, 'autograph_info__')
  return marker_present
313
314
def converted_call(f,
                   args,
                   kwargs,
                   caller_fn_scope=None,
                   options=None):
  """Converts a function call inline.

  For internal use only.

  Note: The argument list is optimized for readability of generated code, which
  may look like this:

    ag__.converted_call(f, (arg1, arg2), None, fscope)
    ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
    ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope)

  Args:
    f: The function to convert.
    args: Tuple, the original positional arguments of f
    kwargs: Optional[Dict], the original keyword arguments of f
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made.
    options: Optional[converter.ConversionOptions], conversion options. If not
      specified, the value of caller_fn_scope.callopts is used. Either options
      or caller_fn_scope must be present.

  Returns:
    Any, the result of executing a possibly-converted `f` with the given
      arguments.
  """
  logging.log(1, 'Converted call: %s\n    args: %s\n    kwargs: %s\n', f, args,
              kwargs)

  # Resolve conversion options from the caller's scope when not given
  # explicitly.
  if options is None:
    if caller_fn_scope is None:
      raise ValueError('either caller_fn_scope or options must have a value')
    options = caller_fn_scope.callopts

  # Fast paths: calls known not to require conversion run unconverted.
  if conversion.is_in_allowlist_cache(f, options):
    logging.log(2, 'Allowlisted %s: from cache', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
    logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if is_autograph_artifact(f):
    logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f)
    return _call_unconverted(f, args, kwargs, options)

  # If this is a partial, unwrap it and redo all the checks.
  if isinstance(f, functools.partial):
    new_kwargs = {}
    if f.keywords is not None:
      # Use copy to avoid mutating the underlying keywords.
      new_kwargs = f.keywords.copy()
    if kwargs is not None:
      new_kwargs.update(kwargs)
    new_args = f.args + args
    logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f, new_args,
                new_kwargs)
    return converted_call(
        f.func,
        new_args,
        new_kwargs,
        caller_fn_scope=caller_fn_scope,
        options=options)

  # Builtins with context-dependent semantics (eval, super, globals, locals)
  # are redirected to overloads that operate in the original caller's context.
  if inspect_utils.isbuiltin(f):
    if f is eval:
      return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
    if f is super:
      return py_builtins.super_in_original_context(f, args, caller_fn_scope)
    if f is globals:
      return py_builtins.globals_in_original_context(caller_fn_scope)
    if f is locals:
      return py_builtins.locals_in_original_context(caller_fn_scope)
    if kwargs:
      return py_builtins.overload_of(f)(*args, **kwargs)
    else:
      return py_builtins.overload_of(f)(*args)

  if conversion.is_unsupported(f):
    return _call_unconverted(f, args, kwargs, options)

  if not options.user_requested and conversion.is_allowlisted(f):
    return _call_unconverted(f, args, kwargs, options)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs, options)

  # Identify the actual entity to transform and the effective positional
  # arguments (bound methods need their target prepended as `self`).
  try:
    if inspect.ismethod(f) or inspect.isfunction(f):
      target_entity = f
      effective_args = args

      f_self = getattr(f, '__self__', None)
      if f_self is not None:
        if isinstance(f_self, function.TfMethodTarget):
          f_self = f_self.target
        effective_args = (f_self,) + effective_args

    elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'):
      # Callable objects. Dunder methods have special lookup rules, see:
      # https://docs.python.org/3/reference/datamodel.html#specialnames
      # TODO(mdan): Recurse into converted_call to simplify other verifications.
      # This should be handled in the same way as partials.
      target_entity = f.__class__.__call__
      effective_args = (f,) + args

    else:
      target_entity = f
      raise NotImplementedError('unknown callable type "%s"' % type(f))

  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise
    return _fall_back_unconverted(f, args, kwargs, options, e)

  # Entities without Python source (native bindings, exec'd code) cannot be
  # transformed; call them as-is.
  if not hasattr(target_entity, '__code__'):
    logging.log(2, 'Permanently allowed: %s: native binding',
                target_entity)
    return _call_unconverted(f, args, kwargs, options)
  elif (hasattr(target_entity.__code__, 'co_filename') and
        target_entity.__code__.co_filename == '<string>'):
    # TODO(mdan): __globals__['txt'] might work in Py3.
    logging.log(2, 'Permanently allowed: %s: dynamic code (exec?)',
                target_entity)
    return _call_unconverted(f, args, kwargs, options)

  # Perform the actual conversion; fall back to the unconverted function on
  # any error unless strict mode is enabled.
  try:
    program_ctx = converter.ProgramContext(options=options)
    converted_f = _convert_actual(target_entity, program_ctx)
    if logging.has_verbosity(2):
      _log_callargs(converted_f, effective_args, kwargs)
  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise
    return _fall_back_unconverted(f, args, kwargs, options, e)

  # Execute the converted function, remapping any error traceback to the
  # original source locations.
  with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
    try:
      if kwargs is not None:
        result = converted_f(*effective_args, **kwargs)
      else:
        result = converted_f(*effective_args)
    except Exception as e:
      _attach_error_metadata(e, converted_f)
      raise

  return result
472
473
474def _call_unconverted(f, args, kwargs, options, update_cache=True):
475  """Calls the original function without converting with AutoGraph."""
476  if update_cache:
477    conversion.cache_allowlisted(f, options)
478
479  if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget):
480    return f.__self__.call(args, kwargs)
481
482  if kwargs is not None:
483    return f(*args, **kwargs)
484  return f(*args)
485
486
def _fall_back_unconverted(f, args, kwargs, options, exc):
  """Falls back to calling the function unconverted, in case of error."""
  # TODO(mdan): Consider adding an internal metric.
  warning_template = (
      'AutoGraph could not transform %s and will run it as-is.\n'
      '%s'
      'Cause: %s\n'
      'To silence this warning, decorate the function with'
      ' @tf.autograph.experimental.do_not_convert')
  known_limitation = isinstance(exc, errors.UnsupportedLanguageElementError)
  if known_limitation:
    # Known language limitations don't need a bug report; warn only the first
    # time for a given function.
    if not conversion.is_in_allowlist_cache(f, options):
      logging.warn(warning_template, f, '', exc)
  else:
    file_bug_message = (
        'Please report this to the TensorFlow team. When filing the bug, set'
        ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
        ' attach the full output.\n')
    logging.warn(warning_template, f, file_bug_message, exc)

  return _call_unconverted(f, args, kwargs, options)
507
508
509#
510# TensorFlow integration
511#
512
513
@tf_export('__internal__.autograph.tf_convert', v1=[])
def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
  """Decorator that applies AutoGraph to a function.

  Use in internal APIs.

  This API is suitable for high order functions internal to the TensorFlow API,
  and more generally any function to which AutoGraph is not applied.

  Guidance: `convert` was a decorator meant for use directly by developers, but
  most of today's uses go through `tf.function`. `tf_convert` is to be called
  from high order functions internal to TF. By default, all the internal
  TensorFlow functions are skipped when AutoGraph processes the code. This may
  lead to user-supplied functions to be incorrectly skipped as well.
  `tf_convert` helps avoid that. See the following example for more details.

  ```
  =====tf_internal_module.py=====

  def unconverted(input_fn):
    return input_fn()

  def converted(input_fn):
    return tf.__internal__.autograph.tf_convert(
       input_fn, ctx=tf.__internal__.autograph.control_status_ctx())()

  ======user_module.py======

  @tf.function
  def foo(input_fn):
    return unconverted(input_fn)

  @tf.function
  def bar(input_fn):
    return converted(input_fn)

  @tf.function(autograph=False)
  def baz(input_fn):
    return converted(input_fn)
  ```

  The `foo` method above will execute the `input_fn` without autograph
  conversion, while the `bar` method will run an autographed `input_fn`. The
  `baz` method will run an unconverted `input_fn`, since `tf_convert` respects
  the control status context.

  Note that both methods in `tf_internal_module` are skipped by autograph when
  tracing the `tf.function`. The configuration of whether a module/package
  should be skipped by autograph is controlled in
  tensorflow/python/autograph/core/config.py.

  Args:
    f: Callable.
    ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
    convert_by_default: bool, whether to use AutoGraph when the context doesn't
      specify.
    user_requested: bool, whether to ignore the conversion allowlist. See
      ConversionOptions.user_requested.

  Returns:
    Either `f` or the converted version of `f`.
  """

  # Already-processed functions are returned as-is.
  if is_autograph_artifact(f):
    return f
  f_wrapper = f
  decorators, f = tf_decorator.unwrap(f)

  # TODO(mdan): Grab features from context.
  # Note: we pass the original context through to convert to properly handle the
  # following scenario, which can be used inside TF implementations:
  #
  #   ctx = ag_ctx.control_status_ctx()
  #   @function(autograph=False)  # Low-level graph code
  #   def inner_fn():
  #     # The context is disabled here, but should be enabled in user_fn
  #     tf_convert(user_fn, ctx=ctx)
  # Pick the wrapper based on the control status: convert, skip, or defer the
  # decision to the conversion-time context.
  if ctx.status == ag_ctx.Status.ENABLED:
    wrapper_factory = convert(
        recursive=True, user_requested=user_requested, conversion_ctx=ctx)
  elif ctx.status == ag_ctx.Status.DISABLED:
    wrapper_factory = do_not_convert
  elif ctx.status == ag_ctx.Status.UNSPECIFIED:
    if convert_by_default:
      wrapper_factory = convert(
          recursive=True, user_requested=user_requested, conversion_ctx=ctx)
    else:
      wrapper_factory = call_with_unspecified_conversion_status
  else:
    assert False, 'This switch contains all possible cases!'
  wrapper = wrapper_factory(f)

  # Re-apply any decorators that were unwrapped above.
  if decorators:
    wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)

  return autograph_artifact(wrapper)
610
611
def call_with_unspecified_conversion_status(func):
  """Decorator that resets the conversion context to the unspecified status."""

  def wrapper(*args, **kwargs):
    # Run func under a context that leaves the conversion decision open.
    with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED):
      return func(*args, **kwargs)

  if inspect.ismethod(func) or inspect.isfunction(func):
    wrapper = functools.update_wrapper(wrapper, func)

  return autograph_artifact(wrapper)
622
623
def _log_callargs(f, args, kwargs):
  """Logging helper."""
  logging.log(2, 'Defaults of %s : %s', f, f.__defaults__)
  if not six.PY2:
    logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__)

  if kwargs is None:
    callargs = tf_inspect.getcallargs(f, *args)
  else:
    callargs = tf_inspect.getcallargs(f, *args, **kwargs)

  lines = ('    {}: {}'.format(k, v) for k, v in callargs.items())
  logging.log(2, 'Calling %s with\n%s\n', f, '\n'.join(lines))
638
639
640#
641# Public API
642#
643
644
@tf_export('autograph.experimental.do_not_convert')
def do_not_convert(func=None):
  """Decorator that suppresses the conversion of a function.

  Args:
    func: function to decorate.

  Returns:
    If `func` is not None, returns a `Callable` which is equivalent to
    `func`, but is not converted by AutoGraph.
    If `func` is None, returns a decorator that, when invoked with a
    single `func` argument, returns a `Callable` equivalent to the
    above case.
  """
  if func is None:
    # Bare parenthesized usage, `@do_not_convert()`: return the decorator.
    return do_not_convert

  def wrapper(*args, **kwargs):
    # Disable AutoGraph for the dynamic extent of the call.
    with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
      return func(*args, **kwargs)

  if inspect.ismethod(func) or inspect.isfunction(func):
    wrapper = functools.update_wrapper(wrapper, func)

  return autograph_artifact(wrapper)
670
671
672# TODO(mdan): Make private.
def convert(recursive=False,
            optional_features=None,
            user_requested=True,
            conversion_ctx=ag_ctx.NullCtx()):
  """Decorator that compiles a function to use TensorFlow ops.

  The decorator is dynamic - it recompiles the target whenever the decorated
  function is called. This means the parameter values are known at conversion.
  It also means that repeated calls with different types of parameters will be
  correctly processed.

  Args:
    recursive: bool, whether to recursively convert any functions or classes
      that the converted function may use.
    optional_features: converted.Feature, allows toggling optional or
      experimental features. When set to None, only the core features are
      enabled.
    user_requested: bool, whether this is a function that the user explicitly
      asked to be converted. See ConversionOptions.user_requested.
    conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context in
      which `f` is used.

  Returns:
    Callable, a decorator that converts the given function into an equivalent
    function that uses TensorFlow ops.
  """

  def decorator(f):
    """Decorator implementation."""

    def wrapper(*args, **kwargs):
      """Wrapper that calls the converted version of f."""
      conversion_options = converter.ConversionOptions(
          recursive=recursive,
          user_requested=user_requested,
          optional_features=optional_features)
      try:
        with conversion_ctx:
          return converted_call(f, args, kwargs, options=conversion_options)
      except Exception as e:  # pylint:disable=broad-except
        # Errors carrying AutoGraph metadata are rewritten so the traceback
        # points at the original source rather than the generated code.
        if hasattr(e, 'ag_error_metadata'):
          raise e.ag_error_metadata.to_exception(e)
        else:
          raise

    if inspect.isfunction(f) or inspect.ismethod(f):
      wrapper = functools.update_wrapper(wrapper, f)

    return autograph_artifact(tf_decorator.make_decorator(f, wrapper))

  return decorator
725
726
727# pylint:disable=line-too-long
@tf_export('autograph.to_graph', v1=[])
def to_graph(entity, recursive=True, experimental_optional_features=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  Example usage:

  >>> def f(x):
  ...   if x > 0:
  ...     y = x * x
  ...   else:
  ...     y = -x
  ...   return y
  ...
  >>> converted_f = to_graph(f)
  >>> x = tf.constant(2)
  >>> converted_f(x)  # converted_foo is like a TensorFlow Op.
  <tf.Tensor: shape=(), dtype=int32, numpy=4>

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound function that have an additional first
  argument called `self`.

  For a tutorial, see the
  [tf.function and AutoGraph guide](https://www.tensorflow.org/guide/function).
  For more detailed information, see the
  [AutoGraph reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md).

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  try:
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            user_requested=True,
            optional_features=experimental_optional_features))
    return autograph_artifact(_convert_actual(entity, program_ctx))
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    # Conversion failures are surfaced as a single ConversionError; since this
    # is raised inside the except block, the original error remains available
    # as implicit context (__context__).
    logging.error(1, 'Error converting %s', entity, exc_info=True)
    raise ConversionError('converting {}: {}: {}'.format(
        entity, e.__class__.__name__, str(e)))
798
799
@tf_export(v1=['autograph.to_graph'])
def to_graph_v1(entity,
                recursive=True,
                arg_values=None,
                arg_types=None,
                experimental_optional_features=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound function that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    arg_values: Deprecated.
    arg_types: Deprecated.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  # The deprecated arguments are accepted for API compatibility and ignored.
  del arg_values
  del arg_types
  return to_graph(
      entity,
      recursive=recursive,
      experimental_optional_features=experimental_optional_features)
869
870
@tf_export(v1=['autograph.to_code'])
def to_code_v1(entity,
               recursive=True,
               arg_values=None,
               arg_types=None,
               indentation='  ',
               experimental_optional_features=None):
  """Returns the source code generated by AutoGraph, as a string.

  Example usage:

  >>> def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f)
  "...def tf__f(x):..."

  Also see: `tf.autograph.to_graph`.

  Note: If a function has been decorated with `tf.function`, pass its
  underlying Python function, rather than the callable that `tf.function`
  creates:

  >>> @tf.function
  ... def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f.python_function)
  "...def tf__f(x):..."

  Args:
    entity: Python callable or class.
    recursive: Whether any functions called by the converted entity should be
      converted recursively as well.
    arg_values: Deprecated.
    arg_types: Deprecated.
    indentation: Deprecated.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    The converted code as string.
  """
  # The deprecated arguments are accepted for API compatibility but ignored;
  # the conversion is delegated entirely to the v2 entry point.
  del arg_values, arg_types, indentation
  return to_code(
      entity,
      recursive=recursive,
      experimental_optional_features=experimental_optional_features)
923
924
@tf_export('autograph.to_code', v1=[])
def to_code(entity, recursive=True, experimental_optional_features=None):
  """Returns the source code generated by AutoGraph, as a string.

  Example usage:

  >>> def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f)
  "...def tf__f(x):..."

  Also see: `tf.autograph.to_graph`.

  Note: If a function has been decorated with `tf.function`, pass its
  underlying Python function, rather than the callable that `tf.function`
  creates:

  >>> @tf.function
  ... def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f.python_function)
  "...def tf__f(x):..."

  Args:
    entity: Python callable or class to convert.
    recursive: Whether any functions called by the converted entity should be
      converted recursively as well.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    The converted code as string.
  """
  converted = to_graph(
      entity,
      recursive=recursive,
      experimental_optional_features=experimental_optional_features)
  # Retrieve the generated entity's source text; dedent strips any common
  # leading whitespace so the returned code reads as top-level source.
  return textwrap.dedent(tf_inspect.getsource(converted))
968
969
# Module-level transpiler instance shared by the conversion entry points in
# this module. NOTE(review): instantiated once at import time — presumably so
# PyToTF's internal state/caches are reused across conversions; confirm
# against PyToTF's definition (not visible in this chunk).
_TRANSPILER = PyToTF()
971