• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""This module contains the user-facing API for AutoGraph."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23import functools
24import os
25import pdb
26import sys
27
28from enum import Enum
29
30# pylint:disable=g-bad-import-order
31import numpy as np
32import six
33# pylint:enable=g-bad-import-order
34
35
36from tensorflow.python.autograph.core import converter
37from tensorflow.python.autograph.impl import conversion
38from tensorflow.python.autograph.operators import py_builtins
39from tensorflow.python.autograph.pyct import compiler
40from tensorflow.python.autograph.pyct import errors
41from tensorflow.python.autograph.pyct import inspect_utils
42from tensorflow.python.autograph.utils import ag_logging as logging
43from tensorflow.python.autograph.utils import py_func
44from tensorflow.python.framework import tensor_util
45from tensorflow.python.util import nest
46from tensorflow.python.util import tf_decorator
47from tensorflow.python.util import tf_inspect
48from tensorflow.python.util.tf_export import tf_export
49
50
def is_autograph_strict_conversion_mode():
  """Reports whether AutoGraph conversion errors should be re-raised.

  Strict mode is controlled by the `AUTOGRAPH_STRICT_CONVERSION` environment
  variable. Its value is parsed as an integer, with '0' used when the variable
  is unset. Any value greater than zero enables strict mode. A value that
  cannot be parsed by `int` raises ValueError.

  Returns:
    bool, True when strict conversion mode is enabled.
  """
  raw_setting = os.environ.get('AUTOGRAPH_STRICT_CONVERSION', '0')
  strictness_level = int(raw_setting)
  return strictness_level > 0
53
54
55# TODO(mdan): Properly document the type hints.
56# TODO(mdan): Reduce the type hint information to (module, type).
57# (currently we require (module + class name, type))
58
59
# TODO(mdan): This should behave like to_graph (e.g. convert statically).
# TODO(znado): Make an alias so can write Verbosity directly without needing
# to write converter.
def convert(
    recursive=False,
    optional_features=converter.Feature.ALL):
  """Decorator that compiles a function to use TensorFlow ops.

  The conversion is dynamic. Every call of the decorated function is routed
  through `converted_call` with `force_conversion=True`, so the target is
  recompiled each time it is called. The parameter values are therefore known
  at conversion time, and repeated calls with parameters of different types
  are processed correctly.

  Args:
    recursive: bool, whether to recursively convert any functions or classes
      that the converted function may use.
    optional_features: converter.Feature, allows toggling optional or
      experimental features. When set to None, only the core features are
      enabled.

  Returns:
    Callable, a decorator that converts the given function into an equivalent
    function that uses TensorFlow ops.
  """

  def decorator(f):
    """Returns a wrapper that converts and then calls `f` on each invocation."""

    def call_converted(*args, **kwargs):
      # A new ConversionOptions object is built on every call.
      per_call_options = converter.ConversionOptions(
          recursive=recursive,
          force_conversion=True,
          optional_features=optional_features,
      )
      return converted_call(f, None, per_call_options, args, kwargs)

    wrapped = functools.wraps(f)(call_converted)
    decorated = tf_decorator.make_decorator(f, wrapped)

    # Sometimes the decorator is just desugared, making it impossible to
    # detect. This attribute makes detection easier.
    setattr(decorated, '__ag_compiled', True)
    return decorated

  return decorator
106
107
class RunMode(Enum):
  """Selects how a function that is excluded from conversion runs in TF.

  Attributes:
   * GRAPH: Invoke the function directly and unchanged. Suitable for functions
       that were already written for TF graphs and create ops themselves.
   * PY_FUNC: Run the function inside a py_func op. Suitable for code that
       only works in plain Python, e.g. code that renders to the display or
       reads keyboard input.
  """
  # Direct call, no wrapping.
  GRAPH = 1
  # Wrapped in a py_func op.
  PY_FUNC = 2
120
121
def do_not_convert_internal(f):
  """Flags an internal function so that AutoGraph does not convert it.

  The function is marked in place by setting its `__ag_compiled` attribute to
  True. The same object is returned, so this can be used as a decorator.
  """
  marker_attribute = '__ag_compiled'
  setattr(f, marker_attribute, True)
  return f
126
127
def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
  """Decorator that suppresses the conversion of a function.

  See also: docs/pyfunc_dtypes.md

  Args:
    run_as: RunMode, specifies how to use the function in TensorFlow.
    return_dtypes: Optional[Iterable[ Union[tf.DType,
      utils.py_func.MatchDType]]], the return data types of the converted
      function, if run_as is RunMode.PY_FUNC. Ignored otherwise. May be set to
      None if the function has no return values.

  Returns:
    Callable, a decorator that wraps the original function. Applying that
    decorator raises ValueError when `run_as` is neither RunMode.GRAPH nor
    RunMode.PY_FUNC.
  """

  def decorator(f):
    """Selects the wrapper matching `run_as` and flags it as compiled."""

    @functools.wraps(f)
    def call_directly(*args, **kwargs):
      return f(*args, **kwargs)

    @functools.wraps(f)
    def call_via_py_func(*args, **kwargs):
      # TODO(mdan): Add support for kwargs.
      if kwargs:
        raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
      # A dummy return is requested whenever no return dtypes were given.
      return py_func.wrap_py_func(
          f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes)

    if run_as == RunMode.GRAPH:
      chosen_wrapper = call_directly
    elif run_as == RunMode.PY_FUNC:
      chosen_wrapper = call_via_py_func
    else:
      raise ValueError('unknown value for run_as: %s' % run_as)

    setattr(chosen_wrapper, '__ag_compiled', True)
    return chosen_wrapper

  return decorator
170
171
def _call_unconverted(f, args, kwargs):
  """Invokes `f` as written, without any AutoGraph conversion.

  When `inspect_utils.istfmethodtarget(f)` is true, the call is dispatched to
  `f.__self__.call(args, kwargs)` instead. In that case `args` and `kwargs`
  are passed as two containers rather than being unpacked.
  """
  if not inspect_utils.istfmethodtarget(f):
    return f(*args, **kwargs)
  return f.__self__.call(args, kwargs)
178
179
180def _is_known_loaded_type(f, module_name, entity_name):
181  """Tests whether the function or method is an instance of a known type."""
182  if (module_name not in sys.modules or
183      not hasattr(sys.modules[module_name], entity_name)):
184    return False
185  type_entity = getattr(sys.modules[module_name], entity_name)
186  if isinstance(f, type_entity):
187    # The method if of this type. Example:
188    #
189    # o = ClassType()
190    # function(o.method)()
191    return True
192  if tf_inspect.ismethod(f):
193    f = six.get_unbound_function(f)
194    # The the unbound method if of this type. Example:
195    #
196    # class ClassType:
197    #   @function
198    #   def method(self):
199    #     ...
200    # o = ClassType()
201    # o.method()
202    if isinstance(f, type_entity):
203      return True
204  return False
205
206
def converted_call(f, owner, options, args, kwargs):
  """Compiles a function call inline. For internal use only.

  The callable is first resolved; when `owner` is given, `f` names an
  attribute of it. Several categories are then called without conversion:
  builtins (via their py_builtins overload), weakrefs, wrapt-decorated
  functions, classes, members of a few stdlib modules, whitelisted entities
  (unless conversion is forced), and any call when
  `options.internal_convert_user_code` is off. Everything else is converted
  with `to_graph` and the converted function is invoked. If conversion raises
  one of the handled error types, the error is logged and the original
  callable is invoked unconverted. In strict conversion mode the error is
  re-raised instead.

  Args:
    f: The callable to invoke, or, when `owner` is not None, the string name
      of the attribute of `owner` to invoke.
    owner: Optional object on which `f` is looked up by name.
    options: converter.ConversionOptions controlling the conversion.
    args: Positional arguments for the call. They are concatenated with the
      stored arguments of functools.partial objects and prefixed with
      `self`/the callable object, so a tuple is expected.
    kwargs: Keyword arguments for the call.

  Returns:
    The value returned by the converted (or unconverted) call.

  Raises:
    ValueError: If `owner` is specified but `f` is not a string.
  """
  logging.log(1,
              'Converted call: %s; owner: %s\n    args: %s\n    kwargs: %s\n',
              f, owner, args, kwargs)

  if owner is not None:
    if not isinstance(f, str):
      raise ValueError(
          'When owner is specified, the function name must be specified as'
          ' a string: {}'.format(f))

    # Special case when the owner is a 'super' object. In that case lookups of
    # dynamic attributes won't work. See
    # inspect_utils.SuperWrapperForDynamicAttrs.
    if isinstance(owner, super):
      owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

    f = getattr(owner, f)

  if inspect_utils.isbuiltin(f):
    return py_builtins.overload_of(f)(*args, **kwargs)

  if _is_known_loaded_type(f, 'weakref', 'ref'):
    logging.log(2, 'Permanently whitelisted: %s: weakref', f)
    return _call_unconverted(f, args, kwargs)

  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
    logging.warn(
        'Entity {} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will be called without transformation.'
        ' You may however apply AutoGraph before the decorator.'.format(f))
    logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
    return _call_unconverted(f, args, kwargs)

  # Constructors are permanently whitelisted.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if tf_inspect.isclass(f):
    logging.log(2, 'Permanently whitelisted: %s: constructor', f)
    return _call_unconverted(f, args, kwargs)

  # Other built-in modules are permanently whitelisted.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  # Note: TF linter disallows importing inspect.
  if any(f in m.__dict__.values()
         for m in (collections, pdb, copy, tf_inspect._inspect)):  # pylint:disable=protected-access
    logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
    return _call_unconverted(f, args, kwargs)

  if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
    return _call_unconverted(f, args, kwargs)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs)

  # Bound before the try block so the error handler below can always refer to
  # it. Otherwise an exception raised before one of the branches assigns it
  # (e.g. while unwrapping functools.partial objects) would turn into an
  # UnboundLocalError inside the handler.
  target_entity = f

  # TODO(mdan): Move this entire block inside to_graph.
  try:  # Begin of transformation error guards

    # Unwrap functools.partial objects
    # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
    while isinstance(f, functools.partial):
      args = f.args + args
      new_kwargs = {}
      if f.keywords is not None:
        new_kwargs.update(f.keywords)
      new_kwargs.update(kwargs)
      kwargs = new_kwargs
      f = f.func

    if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
      # Regular functions
      target_entity = f
      arg_map_target = f
      f_self = inspect_utils.getmethodself(f)

      # TODO(b/119246461): This may be more elegantly handled using __get__?
      if f_self is not None:
        effective_args = (f_self,) + args
      else:
        effective_args = args

    elif tf_inspect.isclass(f):
      # Constructors
      # Note: Until we support class constructurs, and enable whole-class
      # conversion with an experimental flag, this branch is dead code.
      # TODO(mdan): Consider removing unless there is a compelling use case.
      target_entity = f
      arg_map_target = f.__init__
      effective_args = args

    elif hasattr(f, '__call__') and hasattr(f, '__class__'):
      # Callable objects
      target_entity = f.__call__
      arg_map_target = f.__call__
      effective_args = (f,) + args

    else:
      target_entity = f
      raise NotImplementedError('unknown callable type "%s"' % type(f))

    arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
    arg_types = {}
    for name, arg in arg_values.items():
      arg_class = arg.__class__
      arg_types[name] = (arg_class.__name__, arg_class)

    converted_f = to_graph(
        target_entity,
        recursive=options.recursive,
        arg_values=arg_values,
        arg_types=arg_types,
        experimental_optional_features=options.optional_features)

    if logging.has_verbosity(2):
      logging.log(2, 'Defaults of %s : %s', converted_f,
                  converted_f.__defaults__)
      callargs = tf_inspect.getcallargs(converted_f, *effective_args, **kwargs)
      formatted_callargs = '\n'.join(
          '    {}: {}'.format(k, v) for k, v in callargs.items())
      logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)

  # TODO(mdan): Reduce this list.
  except (errors.AutoGraphError, AssertionError, AttributeError, IndexError,
          KeyError, NameError, NotImplementedError, SyntaxError, TypeError,
          ValueError, IOError) as e:

    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)

    if is_autograph_strict_conversion_mode():
      raise

    logging.warn(
        'Entity %s could not be transformed and will be staged without change.'
        ' Error details can be found in the logs when running with the env'
        ' variable AUTOGRAPH_VERBOSITY >= 1. Please report this to the'
        ' AutoGraph team. Cause: %s', target_entity, e)

    return _call_unconverted(f, args, kwargs)

  result = converted_f(*effective_args, **kwargs)

  # The converted function's closure is simply inserted into the function's
  # module __dict__. Since modules are permanently cached, that results in
  # leaking the entire closure.
  # Normally, it's not safe to delete the module because that may release said
  # closure as well. However, in the case of converted_call we are certain the
  # function will not be executed again, so the closure should no longer be
  # needed so long as the function doesn't return any executable code.
  # TODO(mdan): Attach the closure properly, using cells.
  if all(map(_is_not_callable, nest.flatten(result))):
    # pop() rather than del: if the module is not (or no longer) registered,
    # a KeyError here would discard an already computed, valid result.
    sys.modules.pop(converted_f.__module__, None)

  return result
367
368
369def _is_not_callable(obj):
370  # TODO(brianklee): Handle case when obj is a tensor dependent on a py_func.
371  if isinstance(obj, (int, float, complex, str, bool)):
372    return True
373  if isinstance(obj, (np.ndarray, np.generic)):
374    return True
375  if tensor_util.is_tensor(obj):
376    return True
377  return False
378
379
@tf_export('autograph.to_graph')
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
  def foo(x):
    if x > 0:
      y = x * x
    else:
      y = -x
    return y

  converted_foo = to_graph(foo)

  x = tf.constant(1)
  y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
  assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound functions that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable, rather
      than a specific value.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  try:
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            optional_features=experimental_optional_features),
        autograph_module=tf_inspect.getmodule(to_graph))
    nodes, name, namespace = conversion.entity_to_graph(entity, program_ctx,
                                                        arg_values, arg_types)

    compiled_module, _ = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
      # Avoid overwriting entities that have been transformed.
      if key not in compiled_module.__dict__:
        compiled_module.__dict__[key] = val
    compiled = getattr(compiled_module, name)

    # Carry the original default argument values over to the converted code.
    if hasattr(entity, '__defaults__'):
      logging.log(3, 'Default args mapping: %s has: %s', entity,
                  entity.__defaults__)
      compiled.__defaults__ = entity.__defaults__
    else:
      logging.log(3, 'Default args mapping: %s has no __defaults__', entity)

    logging.log(3, 'Namespace of %s includes: %s', compiled,
                compiled_module.__dict__.keys())

    if hasattr(compiled, '__globals__'):
      # Remove self to avoid circular references. This will probably only work
      # so long as the function is not reentrant.
      del compiled.__globals__[name]

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    if getattr(compiled, source_map_attribute_name, None) is not None:
      # TODO(znado): change input problem errors into TransformError
      raise ValueError('cannot convert %s because it has an attribute '
                       '"%s", which is reserved for AutoGraph.' %
                       (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_module.__dict__['ag_source_map__'])

    return compiled
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    # Any of these errors raised during conversion, including the ValueError
    # above, is handed to the internal error reporter.
    errors.report_internal_error(entity, e)
504
505
@tf_export('autograph.to_code')
def to_code(entity,
            recursive=True,
            arg_values=None,
            arg_types=None,
            indentation='  ',
            experimental_optional_features=converter.Feature.ALL):
  """Similar to `to_graph`, but returns Python source code as a string.

  Also see: `tf.autograph.to_graph`.

  `to_code` returns the Python source code that can be used to generate a
  TensorFlow graph that is functionally identical to the input Python code.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable, rather
      than a specific value.
    indentation: The string to use for indenting. Typically two or four spaces,
      or just the tab character.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.

  Returns:
    The converted code as a string, preceded by the program context's
    required imports and a blank line.
  """
  conversion_options = converter.ConversionOptions(
      recursive=recursive,
      optional_features=experimental_optional_features)
  program_ctx = converter.ProgramContext(
      options=conversion_options,
      autograph_module=tf_inspect.getmodule(to_graph))

  # Only the converted AST nodes are needed; the entity name and namespace
  # that entity_to_graph also returns are used by to_graph, not here.
  converted_nodes, _, _ = conversion.entity_to_graph(
      entity, program_ctx, arg_values, arg_types)
  generated_source = compiler.ast_to_source(converted_nodes, indentation)

  header = program_ctx.required_imports
  return header + '\n\n' + generated_source
550