• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Container for origin source code information before AutoGraph compilation."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import difflib
22import os
23import tokenize
24
25import gast
26import six
27
28from tensorflow.python.autograph.pyct import anno
29from tensorflow.python.autograph.pyct import ast_util
30from tensorflow.python.autograph.pyct import parser
31from tensorflow.python.autograph.pyct import pretty_printer
32from tensorflow.python.autograph.utils import ag_logging as logging
33from tensorflow.python.util import tf_inspect
34
35
class LineLocation(
    collections.namedtuple('LineLocation', ['filename', 'lineno'])):
  """A Location reduced to file and line, with no column information.

  Attributes:
    filename: Text
    lineno: int, 1-based
  """
45
46
class Location(
    collections.namedtuple('Location', ['filename', 'lineno', 'col_offset'])):
  """A position in a source file: file, line and column.

  Attributes:
    filename: Text
    lineno: int, 1-based
    col_offset: int
    line_loc: LineLocation, the same position with the column dropped
  """

  @property
  def line_loc(self):
    """LineLocation built from this location's filename and lineno."""
    return LineLocation(filename=self.filename, lineno=self.lineno)
61
62
class OriginInfo(
    collections.namedtuple(
        'OriginInfo',
        ['loc', 'function_name', 'source_code_line', 'comment'])):
  """Describes where a piece of code lived before conversion.

  Attributes:
    loc: Location
    function_name: Optional[Text]
    source_code_line: Text
    comment: Optional[Text]
  """

  def as_frame(self):
    """Returns a 4-tuple consistent with the return of traceback.extract_tb.

    The tuple is (filename, lineno, function_name, source_code_line).
    """
    location = self.loc
    return (location.filename, location.lineno, self.function_name,
            self.source_code_line)

  def __repr__(self):
    location = self.loc
    if not location.filename:
      # No (or empty) filename: print a placeholder instead.
      return '<no file>:{}:{}'.format(location.lineno, location.col_offset)
    # Only the last path component is shown, to keep the output short.
    basename = os.path.basename(location.filename)
    return '{}:{}:{}'.format(basename, location.lineno, location.col_offset)
87
88
89# TODO(mdan): This source map should be a class - easier to refer to.
def create_source_map(nodes, code, filename, indices_in_code):
  """Maps lines of generated code back to the origin of the nodes behind them.

  The code is parsed again, the module-level statements selected by
  indices_in_code are annotated with their positions in code (see resolve),
  and the result is walked in parallel with nodes so that each generated
  location can be paired with the origin annotation of the matching node.

  Args:
    nodes: Iterable[ast.AST, ...], the annotated nodes that compile to code.
    code: Text
    filename: Optional[Text], used as the filename of every key in the
        returned map.
    indices_in_code: Iterable[int, ...], the positions at which nodes appear
        in code. The parser always returns a module when parsing code; each
        index selects the entry of that module's body which corresponds to a
        node. The value is iterated directly.

  Returns:
    Dict[LineLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node.

  Raises:
    ValueError: re-raised unchanged when raised during the parallel walk. At
      logging verbosity 3, a diff between the formatted original and reparsed
      nodes is logged first.
  """
  reparsed_module = parser.parse_str(code)
  reparsed_nodes = [reparsed_module.body[index] for index in indices_in_code]
  for reparsed in reparsed_nodes:
    resolve(reparsed, code)

  source_map = {}

  try:
    for original, generated in ast_util.parallel_walk(nodes, reparsed_nodes):
      # Note: generated code might not be mapped back to its origin.
      # TODO(mdan): Generated code should always be mapped to something.
      origin_info = anno.getanno(original, anno.Basic.ORIGIN, default=None)
      final_info = anno.getanno(generated, anno.Basic.ORIGIN, default=None)
      if origin_info is None or final_info is None:
        continue

      line_loc = LineLocation(filename, final_info.loc.lineno)

      previous = source_map.get(line_loc)
      if previous is not None:
        # Overlaps may exist because of child nodes, but almost never to
        # different line locations. Decorated functions are the exception:
        # both lines are mapped to the same line in the AST.

        # Line overlaps: keep bottom node.
        same_line = previous.loc.line_loc == origin_info.loc.line_loc
        if same_line and previous.loc.lineno >= origin_info.loc.lineno:
          continue

        # In case of overlaps, keep the leftmost node.
        if previous.loc.col_offset <= origin_info.loc.col_offset:
          continue

      source_map[line_loc] = origin_info
  except ValueError:
    # Before propagating the error, log how the two trees differ.
    if logging.has_verbosity(3):
      for original, reparsed in zip(nodes, reparsed_nodes):
        original_lines = pretty_printer.fmt(
            original, color=False, noanno=True).split('\n')
        reparsed_lines = pretty_printer.fmt(
            reparsed, color=False, noanno=True).split('\n')
        diff = '\n'.join(
            difflib.context_diff(
                original_lines,
                reparsed_lines,
                fromfile='Original nodes',
                tofile='Reparsed nodes',
                n=7))
        logging.log(3, 'AST seems to lack integrity. Diff:\n%s', diff)
    raise

  return source_map
156
157
158# TODO(znado): Consider refactoring this into a Visitor.
159# TODO(mdan): Does this work correctly with inner functions?
def resolve(node, source, function=None):
  """Adds origin information to node and its subnodes.

  This allows us to map the original source code line numbers to generated
  source code.

  Every node under `node` (inclusive) that has a `lineno` attribute receives
  an OriginInfo under the anno.Basic.ORIGIN annotation key. Nodes without a
  `lineno` are left untouched.

  Args:
    node: gast.AST node. Should be a gast.FunctionDef. This is the node we
        annotate with origin information.
    source: Text, the source code. Should satisfy relationship
        `node in iter_tree(gast.parse(source))`; otherwise the lineno will be
        unreliable.
    function: The original function. When given, the recorded line numbers
        are shifted so that `node` lands on the line where the function
        starts in its own file, and the file path and function name are
        recorded. If it is None, line numbers are those within `source` and
        the file path and function name are None.
  """
  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
    function_name = function.__name__
  else:
    function_lineno = None
    function_filepath = None
    function_name = None

  # TODO(mdan): Pull this to a separate utility.
  # Collect the comments in `source`, keyed by the row that tokenize reports
  # for the comment token. The text is stored without the leading '#' and
  # without surrounding whitespace.
  comment_map = {}
  for token in tokenize.generate_tokens(six.StringIO(source).readline):
    tok_type, tok_string, (srow, _), _, _ = token
    if tok_type == tokenize.COMMENT:
      comment_map[srow] = tok_string.strip()[1:].strip()

  source_lines = source.split('\n')
  for n in gast.walk(node):
    if not hasattr(n, 'lineno'):
      continue

    within_body_offset = n.lineno - node.lineno

    if function:
      # Translate the position inside `source` into a position inside the
      # file that defines the function.
      source_lineno = function_lineno + within_body_offset
    else:
      source_lineno = n.lineno

    location = Location(function_filepath, source_lineno, n.col_offset)
    # Both source_lines and comment_map are indexed by positions within
    # `source`, so they must be looked up with n.lineno. source_lineno may
    # have been shifted to the function's own file and would select the
    # wrong comment (or none).
    origin = OriginInfo(location, function_name,
                        source_lines[n.lineno - 1], comment_map.get(n.lineno))
    anno.setanno(n, anno.Basic.ORIGIN, origin)
211