# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Container for origin source code information before AutoGraph compilation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import difflib
import os
import tokenize

import gast
import six

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.autograph.utils import ag_logging as logging
from tensorflow.python.util import tf_inspect


class LineLocation(
    collections.namedtuple('LineLocation', ('filename', 'lineno'))):
  """Similar to Location, but without column information.

  Attributes:
    filename: Text
    lineno: int, 1-based
  """
  pass


class Location(
    collections.namedtuple('Location', ('filename', 'lineno', 'col_offset'))):
  """Encodes code location information.

  Attributes:
    filename: Text
    lineno: int, 1-based
    col_offset: int
    line_loc: LineLocation, this location without the column (read-only
      property).
  """

  @property
  def line_loc(self):
    return LineLocation(self.filename, self.lineno)


class OriginInfo(
    collections.namedtuple(
        'OriginInfo',
        ('loc', 'function_name', 'source_code_line', 'comment'))):
  """Container for information about the source code before conversion.

  Attributes:
    loc: Location
    function_name: Optional[Text]
    source_code_line: Text
    comment: Optional[Text]
  """

  def as_frame(self):
    """Returns a 4-tuple consistent with the return of traceback.extract_tb."""
    return (self.loc.filename, self.loc.lineno, self.function_name,
            self.source_code_line)

  def __repr__(self):
    """Formats as 'basename:lineno:col_offset', or '<no file>:...' if none."""
    if self.loc.filename:
      return '{}:{}:{}'.format(
          os.path.split(self.loc.filename)[1], self.loc.lineno,
          self.loc.col_offset)
    return '<no file>:{}:{}'.format(self.loc.lineno, self.loc.col_offset)


# TODO(mdan): This source map should be a class - easier to refer to.
def create_source_map(nodes, code, filename, indices_in_code):
  """Creates a source map between an annotated AST and the code it compiles to.

  `code` is parsed again, the selected top-level statements are annotated by
  `resolve`, and the result is walked in parallel with `nodes`. Each line of
  `code` is mapped to the origin annotation of the corresponding original
  node.

  Args:
    nodes: Iterable[ast.AST, ...], the original nodes, annotated with
      anno.Basic.ORIGIN.
    code: Text, the code generated from `nodes`.
    filename: Optional[Text], the filename used in the keys of the result.
    indices_in_code: Union[int, Iterable[int, ...]], the positions at which
      nodes appear in code. The parser always returns a module when parsing
      code. This argument indicates the position in that module's body at
      which the corresponding node should appear. A single int is treated as
      a one-element sequence.

  Returns:
    Dict[LineLocation, OriginInfo], mapping line locations in `code` (with
    `filename` as the file) to the origin information of the original node
    found at that position.

  Raises:
    ValueError: propagated from ast_util.parallel_walk, presumably when
      `nodes` and the reparsed code differ in structure. If verbosity level 3
      is enabled, a diff of the two trees is logged before re-raising.
  """
  if isinstance(indices_in_code, int):
    # A bare int is documented as valid input, but is not iterable.
    indices_in_code = (indices_in_code,)

  module = parser.parse_str(code)
  reparsed_nodes = [module.body[i] for i in indices_in_code]
  for node in reparsed_nodes:
    # No function is passed, so the annotations use line numbers in `code`.
    resolve(node, code)

  source_map = {}

  try:
    for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
      # Note: generated code might not be mapped back to its origin.
      # TODO(mdan): Generated code should always be mapped to something.
      origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
      final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
      if origin_info is None or final_info is None:
        continue

      # Keys are by line only; the generated node's column is dropped.
      line_loc = LineLocation(filename, final_info.loc.lineno)

      existing_origin = source_map.get(line_loc)
      if existing_origin is not None:
        # Overlaps may exist because of child nodes, but almost never to
        # different line locations. Decorated functions are an exception:
        # both lines are mapped to the same line in the AST.

        # Both candidates point at the same original line. Equal line_locs
        # imply equal linenos, so the origin recorded first is kept.
        if existing_origin.loc.line_loc == origin_info.loc.line_loc:
          continue

        # Otherwise keep the leftmost node; on a tie keep the existing one.
        if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
          continue

      source_map[line_loc] = origin_info
  except ValueError:
    if logging.has_verbosity(3):
      for n, rn in zip(nodes, reparsed_nodes):
        nodes_str = pretty_printer.fmt(n, color=False, noanno=True)
        reparsed_nodes_str = pretty_printer.fmt(rn, color=False, noanno=True)
        diff = difflib.context_diff(
            nodes_str.split('\n'),
            reparsed_nodes_str.split('\n'),
            fromfile='Original nodes',
            tofile='Reparsed nodes',
            n=7)
        diff = '\n'.join(diff)
        logging.log(3, 'AST seems to lack integrity. Diff:\n%s', diff)
    raise

  return source_map


# TODO(znado): Consider refactoring this into a Visitor.
# TODO(mdan): Does this work correctly with inner functions?
def resolve(node, source, function=None):
  """Adds origin information to node and its subnodes.

  This allows us to map the original source code line numbers to generated
  source code. Every node visited by `gast.walk(node)` that has a non-None
  `lineno` receives an OriginInfo under anno.Basic.ORIGIN.

  Args:
    node: gast.AST node. Should be a gast.FunctionDef. This is the node we
      annotate with origin information.
    source: Text, the source code. Should satisfy relationship
      `node in iter_tree(gast.parse(source))`; otherwise the lineno will be
      unreliable.
    function: The original function. If it is None, line numbers are taken
      as-is from `source`, and the filename and function name in the
      annotation are None. The source line text and comment are always set.
  """
  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
    function_name = function.__name__
  else:
    function_lineno = None
    function_filepath = None
    function_name = None

  # TODO(mdan): Pull this to a separate utility.
  # Comment text (without the leading '#'), keyed by its 1-based line number
  # in `source`.
  comment_map = {}
  code_reader = six.StringIO(source)
  for tok_type, tok_string, (srow, _), _, _ in tokenize.generate_tokens(
      code_reader.readline):
    if tok_type == tokenize.COMMENT:
      comment_map[srow] = tok_string.strip()[1:].strip()

  source_lines = source.split('\n')
  for n in gast.walk(node):
    # Skip nodes without position information, or with it set to None.
    lineno = getattr(n, 'lineno', None)
    if lineno is None:
      continue

    if function:
      # Translate from `source` coordinates to the original file, anchoring
      # `node` itself at the line reported by tf_inspect.getsourcelines.
      # NOTE(review): for decorated functions, getsourcelines and
      # node.lineno may disagree on decorator vs. `def` line - confirm.
      source_lineno = function_lineno + (lineno - node.lineno)
    else:
      source_lineno = lineno

    location = Location(function_filepath, source_lineno,
                        getattr(n, 'col_offset', None))
    # Line text and comment are both indexed by the line in `source`, not by
    # the translated line in the original file.
    origin = OriginInfo(location, function_name, source_lines[lineno - 1],
                        comment_map.get(lineno))
    anno.setanno(n, anno.Basic.ORIGIN, origin)