• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converting AST to code.
16
17Adapted from Tangent.
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24# TODO(mdan): Use six for compatibility here.
25import atexit
26import imp
27import os
28import tempfile
29
30import astor
31import gast
32
33from tensorflow.python.autograph.pyct import origin_info
34
35
def ast_to_source(node, indentation='  '):
  """Render one or more AST nodes as Python source text.

  Args:
    node: A single AST node, or a list/tuple of AST nodes. Any node that is a
      `gast.AST` instance is converted with `gast.gast_to_ast` before it is
      rendered.
    indentation: The string to use for indentation.

  Returns:
    The generated source code, as a single string with leading blank lines
    removed.
  """
  nodes = node if isinstance(node, (list, tuple)) else (node,)

  source_gen = astor.code_gen.SourceGenerator(
      indentation, False, astor.string_repr.pretty_string)

  for item in nodes:
    std_node = gast.gast_to_ast(item) if isinstance(item, gast.AST) else item
    source_gen.visit(std_node)
    # Terminate the output of every top-level node with a newline.
    source_gen.result.append('\n')

  # In some versions of Python, literals may show up in the result as actual
  # values rather than text, so every fragment is coerced to a string.
  code = ''.join(str(fragment) for fragment in source_gen.result)

  # Drop every blank line that precedes the first line with content. Blank
  # lines after that point are preserved.
  lines = code.split('\n')
  first_content = next(
      (i for i, line in enumerate(lines) if line.rstrip()), len(lines))
  code = '\n'.join(lines[first_content:])

  # Break the reference cycle that astor creates on the generator object.
  # See https://github.com/berkerpeksag/astor/blob/55dd323f7d8d696610c703c0296763c567685c31/astor/code_gen.py#L162  # pylint:disable=line-too-long
  # TensorFlow's tests do not tolerate reference cycles.
  if hasattr(source_gen, 'write'):
    source_gen.write = None
  del source_gen

  return code
78
79
def _source_to_module(source, delete_on_exit):
  """Write source code to a temporary .py file and load it as a module.

  Args:
    source: Text, the Python source code to write and load.
    delete_on_exit: bool, whether to register an `atexit` hook that removes
      the temporary file.

  Returns:
    (module, filename): The module loaded from the temporary file, and the
    path of that file.
  """
  with tempfile.NamedTemporaryFile(
      mode='w', suffix='.py', delete=False) as tmp_file:
    tmp_file.write(source)
  file_path = tmp_file.name

  # The module is named after the file, minus its '.py' suffix.
  module_name = os.path.basename(file_path)[:-len('.py')]

  # TODO(mdan): Try flush() and delete=False instead.
  if delete_on_exit:

    def _remove_tmp_file():
      os.remove(file_path)

    atexit.register(_remove_tmp_file)

  module = imp.load_source(module_name, file_path)
  return module, file_path
89
90
def ast_to_object(nodes,
                  indentation='  ',
                  include_source_map=False,
                  source_prefix=None,
                  delete_on_exit=True):
  """Return the Python objects represented by given AST.

  Compiling the AST code this way ensures that the source code is readable by
  e.g. `pdb` or `inspect`.

  Args:
    nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as an AST
        object.
    indentation: Text, the string to use for indentation.
    include_source_map: bool, whether to attach a source map to the compiled
        object. Also see origin_info.py.
    source_prefix: Optional[Text], string to print as-is into the source file.
    delete_on_exit: bool, whether to delete the temporary file used for
        compilation on exit.

  Returns:
    (module, source): A compiled module, and the source code of the module.
  Raises:
    ValueError: If ag_source_map__ is already in the namespace of the compiled
    nodes.
  """
  if not isinstance(nodes, (list, tuple)):
    nodes = (nodes,)

  source = ast_to_source(nodes, indentation=indentation)

  if source_prefix:
    source = source_prefix + '\n' + source

  module, filename = _source_to_module(source, delete_on_exit)

  if include_source_map:
    # `nodes` was normalized to a list or tuple above, so there is always one
    # (negative) index per node here.
    indices = range(-len(nodes), 0)

    source_map = origin_info.create_source_map(nodes, source, filename, indices)

    # TODO(znado): Clean this up so we don't need to attach it to the namespace.
    # We cannot get the rewritten function name until it is too late so
    # templating is hard, and this cleanly fixes the issues encountered with
    # nested functions because this is attached to the outermost one.
    # TODO(mdan): This name should be decided by the caller.
    source_map_name = 'ag_source_map__'
    # Raise explicitly rather than assert: asserts are stripped under -O, and
    # the documented error type is ValueError.
    if source_map_name in module.__dict__:
      raise ValueError(
          'cannot convert %s because it has namespace attribute "%s", which is '
          'reserved for AutoGraph.' % (module, source_map_name))
    module.__dict__[source_map_name] = source_map

  return module, source
147