1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Converting AST to code. 16 17Adapted from Tangent. 18""" 19 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24# TODO(mdan): Use six for compatibility here. 25import atexit 26import imp 27import os 28import tempfile 29 30import astor 31import gast 32 33from tensorflow.python.autograph.pyct import origin_info 34 35 36def ast_to_source(node, indentation=' '): 37 """Return the source code of given AST. 38 39 Args: 40 node: The code to compile, as an AST object. 41 indentation: The string to use for indentation. 42 43 Returns: 44 code: The source code generated from the AST object 45 source_mapping: A mapping between the user and AutoGraph generated code. 46 """ 47 if not isinstance(node, (list, tuple)): 48 node = (node,) 49 generator = astor.code_gen.SourceGenerator(indentation, False, 50 astor.string_repr.pretty_string) 51 52 for n in node: 53 if isinstance(n, gast.AST): 54 n = gast.gast_to_ast(n) 55 generator.visit(n) 56 generator.result.append('\n') 57 58 # In some versions of Python, literals may appear as actual values. This 59 # ensures everything is string. 60 code = ''.join(map(str, generator.result)) 61 62 # Strip leading blank lines. 63 code_lines = code.split('\n') 64 trimmed_code_lines = [] 65 for l in code_lines: 66 if l.rstrip() or trimmed_code_lines: 67 trimmed_code_lines.append(l) 68 code = '\n'.join(trimmed_code_lines) 69 70 # Work around the reference cycle generated by astor. 71 # See https://github.com/berkerpeksag/astor/blob/55dd323f7d8d696610c703c0296763c567685c31/astor/code_gen.py#L162 # pylint:disable=line-too-long 72 # Reference cycles are quite disliked by TensorFlow's tests. 73 if hasattr(generator, 'write'): 74 generator.write = None 75 del generator 76 77 return code 78 79 80def _source_to_module(source, delete_on_exit): 81 with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: 82 module_name = os.path.basename(f.name[:-3]) 83 f.write(source) 84 85 # TODO(mdan): Try flush() and delete=False instead. 86 if delete_on_exit: 87 atexit.register(lambda: os.remove(f.name)) 88 return imp.load_source(module_name, f.name), f.name 89 90 91def ast_to_object(nodes, 92 indentation=' ', 93 include_source_map=False, 94 source_prefix=None, 95 delete_on_exit=True): 96 """Return the Python objects represented by given AST. 97 98 Compiling the AST code this way ensures that the source code is readable by 99 e.g. `pdb` or `inspect`. 100 101 Args: 102 nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as an AST 103 object. 104 indentation: Text, the string to use for indentation. 105 include_source_map: bool, whether to attach a source map to the compiled 106 object. Also see origin_info.py. 107 source_prefix: Optional[Text], string to print as-is into the source file. 108 delete_on_exit: bool, whether to delete the temporary file used for 109 compilation on exit. 110 111 Returns: 112 (module, source): A compiled module, and the source code of the module. 113 Raises: 114 ValueError: If ag_source_map__ is already in the namespace of the compiled 115 nodes. 116 """ 117 if not isinstance(nodes, (list, tuple)): 118 nodes = (nodes,) 119 120 source = ast_to_source(nodes, indentation=indentation) 121 122 if source_prefix: 123 source = source_prefix + '\n' + source 124 125 module, filename = _source_to_module(source, delete_on_exit) 126 127 if include_source_map: 128 if isinstance(nodes, (list, tuple)): 129 indices = range(-len(nodes), 0) 130 else: 131 indices = (-1,) 132 133 source_map = origin_info.create_source_map(nodes, source, filename, indices) 134 135 # TODO(znado): Clean this up so we don't need to attach it to the namespace. 136 # We cannot get the rewritten function name until it is too late so 137 # templating is hard, and this cleanly fixes the issues encountered with 138 # nested functions because this is attached to the outermost one. 139 # TODO(mdan): This name should be decided by the caller. 140 source_map_name = 'ag_source_map__' 141 assert source_map_name not in module.__dict__, ( 142 'cannot convert %s because is has namespace attribute "%s", which is ' 143 'reserved for AutoGraph.') % (module, source_map_name) 144 module.__dict__[source_map_name] = source_map 145 146 return module, source 147