• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Core conversion logic, serves as main point of access."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
import functools
import imp
import types
import unittest
24
25import gast
26
27from tensorflow.python.autograph import operators
28from tensorflow.python.autograph import utils
29from tensorflow.python.autograph.converters import arg_defaults
30from tensorflow.python.autograph.converters import asserts
31from tensorflow.python.autograph.converters import break_statements
32from tensorflow.python.autograph.converters import builtin_functions
33from tensorflow.python.autograph.converters import call_trees
34from tensorflow.python.autograph.converters import conditional_expressions
35from tensorflow.python.autograph.converters import continue_statements
36from tensorflow.python.autograph.converters import control_flow
37from tensorflow.python.autograph.converters import directives
38from tensorflow.python.autograph.converters import error_handlers
39from tensorflow.python.autograph.converters import function_scopes
40from tensorflow.python.autograph.converters import lists
41from tensorflow.python.autograph.converters import logical_expressions
42from tensorflow.python.autograph.converters import return_statements
43from tensorflow.python.autograph.converters import side_effect_guards
44from tensorflow.python.autograph.converters import slices
45from tensorflow.python.autograph.core import config
46from tensorflow.python.autograph.core import converter
47from tensorflow.python.autograph.core import errors as ag_errors
48from tensorflow.python.autograph.core import function_wrapping
49from tensorflow.python.autograph.core import naming
50from tensorflow.python.autograph.core import unsupported_features_checker
51from tensorflow.python.autograph.lang import special_functions
52from tensorflow.python.autograph.pyct import ast_util
53from tensorflow.python.autograph.pyct import compiler
54from tensorflow.python.autograph.pyct import errors
55from tensorflow.python.autograph.pyct import inspect_utils
56from tensorflow.python.autograph.pyct import origin_info
57from tensorflow.python.autograph.pyct import parser
58from tensorflow.python.autograph.pyct import pretty_printer
59from tensorflow.python.autograph.pyct import qual_names
60from tensorflow.python.autograph.pyct import templates
61from tensorflow.python.autograph.pyct import transformer
62from tensorflow.python.autograph.utils import ag_logging as logging
63from tensorflow.python.util import tf_inspect
64
65
66# TODO(mdan): Might we not need any renaming at all?
67
68
def is_whitelisted_for_graph(o):
  """Check whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package. The checks below run in order and the first match wins:
    * the entity's module name starts with a prefix listed in
      `config.DEFAULT_UNCOMPILED_MODULES`, or names a tensorboard module;
    * the entity was already converted (it carries an `autograph_info__` or
      `__ag_compiled` attribute);
    * the entity is a callable object whose `__call__` is whitelisted;
    * the entity is a method of a `unittest.TestCase` subclass, or a method
      whose defining class is whitelisted;
    * the entity is a namedtuple.

  Args:
    o: A Python entity.
  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)

  # Builtins typically have unnamed modules.
  if hasattr(m, '__name__'):
    # Each entry of DEFAULT_UNCOMPILED_MODULES is unpacked as a one-element
    # sequence holding a module-name prefix.
    for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
      if m.__name__.startswith(prefix):
        logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
        return True

    # Temporary -- whitelist tensorboard modules.
    # TODO(b/122731813): Remove.
    if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__:
      logging.log(2, 'Whitelisted: %s: name contains "tensorboard"', o)
      return True

  if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
    logging.log(2, 'Whitelisted: %s: already converted', o)
    return True

  if hasattr(o, '__call__'):
    # Callable objects: whitelisted if their __call__ method is.
    # The type check avoids infinite recursion around the __call__ method
    # of function objects.
    if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__):  # pylint: disable=unidiomatic-typecheck
      logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
      return True

  if tf_inspect.ismethod(o):
    # Methods of whitelisted classes are also whitelisted, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
    # whitelisted.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be whitelisted.

    owner_class = inspect_utils.getmethodclass(o)
    if owner_class is not None:
      if issubclass(owner_class, unittest.TestCase):
        logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
        return True

      owner_class = inspect_utils.getdefiningclass(o, owner_class)
      if is_whitelisted_for_graph(owner_class):
        logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                    owner_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if tf_inspect.isclass(o) and len(o.__bases__) > 1:
      # `o` is passed as a lazy %-format argument matching the single `%s`
      # placeholder; passing extra positional arguments would make the log
      # record fail to format.
      logging.warn(
          'Entity %s looks like a namedtuple subclass. Its constructor will'
          ' not be converted by AutoGraph, but if it has any custom methods,'
          ' those will be.', o)
    logging.log(2, 'Whitelisted: %s: named tuple', o)
    return True

  logging.log(2, 'Not whitelisted: %s: default rule', o)
  return False
156
157
def entity_to_graph(o, program_ctx, arg_values, arg_types):
  """Compile a Python entity into equivalent TensorFlow.

  Classes are delegated to `class_to_graph`; functions and methods to
  `function_to_graph`. An `autograph_info__` attribute assignment is appended
  to the generated code; `is_whitelisted_for_graph` treats entities carrying
  that attribute as already converted.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.

  Returns:
    A tuple (nodes, new_name, namespace):
        * nodes: A list of AST nodes representing an entity with interface
            equivalent to `o`, but which when executed creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted entity,
            keyed by their symbol name.

  Raises:
    NotImplementedError: if `o` is an object other than a class, function or
      method.
    ValueError: if `o` is not supported and has no `__class__` attribute
      (in practice nearly every Python value has one).
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    nodes, name, ns = class_to_graph(o, program_ctx)
  elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
    nodes, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
  elif hasattr(o, '__class__'):
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'python/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  # TODO(mdan): This is temporary. it should be created using a converter.
  # TODO(mdan): The attribute should be added with a helper, not directly.
  # The helper can ensure there are no collisions.
  template = '''
      entity.autograph_info__ = {}
  '''
  nodes.extend(templates.replace(template, entity=name))

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                compiler.ast_to_source(nodes))
  if logging.has_verbosity(4):
    for n in nodes:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(n, color=False))

  return nodes, name, ns
227
228
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Converts every function or method member defined directly on `c` (inherited
  members are skipped) and assembles the results into a new class definition.
  Each base class must be `object` or whitelisted; whitelisted bases are
  brought into scope by a generated absolute `from ... import ... as ...`
  statement under a fresh alias.

  Args:
    c: The class to convert.
    program_ctx: A ProgramContext object.

  Returns:
    A tuple (nodes, class_name, namespace): the list of generated AST nodes
    (base-class imports followed by the class definition), the name of the
    converted class, and the merged namespace of all converted members.

  Raises:
    ValueError: if `c` has no member methods.
    NotImplementedError: if a base class is neither `object` nor whitelisted.
  """
  # TODO(mdan): Revisit this altogether. Not sure we still need it.
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  # Union of the namespaces of all converted members. It starts as a fresh
  # dict, so merging never mutates any member's own namespace.
  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    nodes, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        do_rename=False)
    class_namespace.update(namespace)
    converted_members[m] = nodes[0]
  namer = naming.Namer(class_namespace)
  class_name = namer.class_name(c.__name__)

  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated.
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Compare by identity: `isinstance(object, base)` would also be true for
    # `type` and for structural ABCs such as `collections.abc.Hashable`,
    # silently replacing those bases with `object`.
    if base is object:
      base_names.append('object')
      continue
    if not is_whitelisted_for_graph(base):
      raise NotImplementedError(
          'Conversion of classes that do not directly extend classes from'
          ' whitelisted modules is temporarily suspended. If this breaks'
          ' existing code please notify the AutoGraph team immediately.')
    alias = namer.new_symbol(base.__name__, ())
    output_nodes.append(
        gast.ImportFrom(
            module=base.__module__,
            names=[gast.alias(name=base.__name__, asname=alias)],
            level=0))
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  return output_nodes, class_name, class_namespace
299
300
301def _add_reserved_symbol(namespace, name, entity):
302  if name not in namespace:
303    namespace[name] = entity
304  elif namespace[name] != entity:
305    raise ValueError('The name "%s" is reserved and may not be used.' % name)
306
307
# Module exposed to converted code under the reserved name `ag__`. It is built
# lazily by `_add_self_references` on first use and reused afterwards.
ag_internal = None


# TODO(mdan): Move into core or replace with an actual importable module.
def _add_self_references(namespace, autograph_module):
  """Adds namespace references to the module that exposes the api itself.

  On the first call this builds the module-level `ag_internal` module; every
  call then binds it to the reserved name `ag__` in `namespace`.

  Args:
    namespace: dict, the namespace visible to a converted entity; updated in
      place.
    autograph_module: a module whose `__dict__` is copied into `ag_internal`
      when it is first built. Later calls do not re-read it.

  Raises:
    ValueError: if `namespace` already binds `ag__` to a different value.
  """
  global ag_internal
  if ag_internal is None:
    # Craft a module that exposes parts of the external API as well as certain
    # internal modules. `types.ModuleType` is the documented replacement for
    # the deprecated `imp.new_module`.
    ag_internal = types.ModuleType('autograph')
    ag_internal.__dict__.update(autograph_module.__dict__)
    ag_internal.ConversionOptions = converter.ConversionOptions
    ag_internal.Feature = converter.Feature
    ag_internal.utils = utils
    ag_internal.function_scope = function_wrapping.function_scope
    ag_internal.rewrite_graph_construction_error = (
        ag_errors.rewrite_graph_construction_error)
    # TODO(mdan): Add safeguards against name clashes.
    # We don't want to create a submodule because we want the operators to be
    # accessible as ag__.<operator>
    ag_internal.__dict__.update(special_functions.__dict__)
    ag_internal.__dict__.update(operators.__dict__)

  _add_reserved_symbol(namespace, 'ag__', ag_internal)
333
334
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
  """Specialization of `entity_to_graph` for functions, methods and lambdas.

  Parses the source of `f`, runs the converter pipeline (`node_to_graph`) on
  it and decides the name the converted code is bound to.

  Args:
    f: The function, method or lambda to convert.
    program_ctx: A ProgramContext object; its `autograph_module` is used to
      populate the `ag__` symbol in the namespace.
    arg_values: A dict of value hints, forwarded to the entity info.
    arg_types: A dict of type hints, forwarded to the entity info.
    do_rename: If True, the converted definition receives a fresh name; if
      False, it keeps `f.__name__`. Lambdas are always bound to a fresh name.

  Returns:
    A tuple (nodes, new_name, namespace): a one-element list holding the
    converted node (an assignment statement when `f` is a lambda), the name it
    is bound to, and the namespace of `f` including the `ag__` symbol.

  Raises:
    ValueError: if the source of a lambda cannot be identified unambiguously.
    errors.InternalError: if conversion fails with ValueError, AttributeError,
      KeyError or NotImplementedError.
  """
  root, source, _ = parser.parse_entity(f)
  logging.log(3, 'Source code of %s:\n\n%s\n', f, source)

  # In general, the output of inspect.getsource is inexact for lambdas because
  # it uses regex matching to adjust the exact location around the line number
  # that CPython records. Then, the entire containing line is returned, which
  # we may have trouble disambiguating. For example:
  # x, y = lambda: 1, lambda: 2
  if f.__name__ == '<lambda>':
    candidates = ast_util.find_matching_definitions(root, f)
    if len(candidates) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    root = candidates[0]

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(root, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = naming.Namer(namespace)

  context = converter.EntityContext(
      namer,
      transformer.EntityInfo(
          source_code=source,
          source_file='<fragment>',
          namespace=namespace,
          arg_values=arg_values,
          arg_types=arg_types),
      program_ctx)

  # TODO(mdan): Catch and rethrow syntax errors.
  try:
    converted = node_to_graph(root, context)
  except (ValueError, AttributeError, KeyError, NotImplementedError) as e:
    logging.error(1, 'Error converting %s', f, exc_info=True)
    raise errors.InternalError('conversion', e)

  if isinstance(converted, gast.Lambda):
    # A lambda has no name of its own: bind it to a fresh symbol.
    new_name = namer.new_symbol('tf__lambda', ())
    converted = gast.Assign(
        targets=[gast.Name(new_name, gast.Store(), None)], value=converted)
  elif not do_rename:
    new_name = f.__name__
    assert converted.name == new_name
  else:
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name = namer.function_name(f.__name__)
    converted.name = new_name

  return [converted], new_name, namespace
390
391
def node_to_graph(node, context):
  """Convert Python code to equivalent TF graph mode code.

  Verifies that the code uses no unsupported features, runs the standard
  analysis, then applies the converters below in a fixed order. The order is
  significant (see the note on continue canonicalization). Converters guarded
  by a `converter.Feature` run only when that feature is enabled in
  `context.program.options`.

  Side effect: `context.info.source_code` is set to None, because line numbers
  no longer match the original source after the first analysis pass.

  Args:
    node: AST, the code to convert.
    context: converter.EntityContext

  Returns:
    The converted AST node, representing the code in graph mode form.
  """
  # TODO(mdan): Insert list_comprehensions somewhere.
  unsupported_features_checker.verify(node)

  node = converter.standard_analysis(node, context, is_initial=True)
  # Past this point, line numbers are no longer accurate so we ignore the
  # source.
  # TODO(mdan): Is it feasible to reconstruct intermediate source code?
  context.info.source_code = None
  node = converter.apply_(node, context, arg_defaults)
  node = converter.apply_(node, context, directives)
  node = converter.apply_(node, context, break_statements)
  if context.program.options.uses(converter.Feature.ASSERT_STATEMENTS):
    node = converter.apply_(node, context, asserts)
  # Note: sequencing continue canonicalization before for loop one avoids
  # dealing with the extra loop increment operation that the for
  # canonicalization creates.
  node = converter.apply_(node, context, continue_statements)
  node = converter.apply_(node, context, return_statements)
  if context.program.options.uses(converter.Feature.LISTS):
    node = converter.apply_(node, context, lists)
    node = converter.apply_(node, context, slices)
  if context.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS):
    node = converter.apply_(node, context, builtin_functions)
  node = converter.apply_(node, context, call_trees)
  node = converter.apply_(node, context, control_flow)
  node = converter.apply_(node, context, conditional_expressions)
  if context.program.options.uses(converter.Feature.LOGICAL_EXPRESSIONS):
    node = converter.apply_(node, context, logical_expressions)
  if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
    node = converter.apply_(node, context, side_effect_guards)
  # TODO(mdan): If function scopes ever does more, the toggle will need moving.
  if context.program.options.uses(converter.Feature.NAME_SCOPES):
    node = converter.apply_(node, context, function_scopes)
  if context.program.options.uses(converter.Feature.ERROR_REWRITING):
    node = converter.apply_(node, context, error_handlers)
  return node
441