• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Code transformation exceptions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23from tensorflow.python.autograph.pyct import origin_info
24
25
class FrameInfo(
    collections.namedtuple(
        'FrameInfo',
        'filename lineno function_name code is_converted is_allowlisted')):
  """Immutable record for one frame of a translated stack trace.

  Fields:
    filename: path of the file the frame refers to.
    lineno: line number within `filename`.
    function_name: name of the function executing in the frame.
    code: text of the source line, or None when it is unavailable.
    is_converted: True when the frame was mapped back to user code through a
      source map.
    is_allowlisted: True when the frame was called from the converted module
      (whose own frames are elided from the summary).
  """

  # No per-instance __dict__; instances stay as lightweight as plain tuples.
  __slots__ = ()
32
33
def _stack_trace_inside_mapped_code(tb, source_map, converter_filename):
  """Summarizes inner traceback frames up to the call to a given function.

  This function locates the innermost (i.e. most recent) frame whose location
  can be mapped back to original source code through `source_map`, and returns
  a translated stack trace ending at that frame. If no such frame is found, the
  entire stack trace is summarized.

  For example, the following code:

    def f():
      for i in tf.range(1):
        z = y + i  # z only defined here

  Would generate this traceback:

    <converted code>
        ag__.for_stmt(...)
    <for_stmt>
        return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
    <_known_len_tf_for_stmt>
        _disallow_undefs_into_loop(*init_state)
    <_disallow_undefs_into_loop>
        raise ...

  Which is then processed into:

    <f>
        for i in tf.range(1):
    <for_stmt>
        return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
    <_known_len_tf_for_stmt>
        _disallow_undefs_into_loop(*init_state)
    <_disallow_undefs_into_loop>
        raise ...

  Args:
    tb: Sequence of traceback frames, each unpackable as a
      (filename, lineno, function_name, line) 4-tuple. Typically a
      traceback.StackSummary, e.g. the output of
      traceback.StackSummary.extract(..., capture_locals=True).
    source_map: Dict[LineLocation, OriginInfo], a source map as created by
      origin_info.create_source_map.
    converter_filename: str, the file path of the converted module. Call frames
      corresponding to this module are elided and their preceding frames are
      marked as allowlisted. Note that frames enclosing converted code are
      dropped using a different mechanism.

  Returns:
    Tuple[FrameInfo, ...], in the reverse order of `tb`. For a conventional
    outermost-first traceback, the innermost frame comes first and the mapped
    (converted) frame, if one was found, comes last.
  """
  result_frames = []
  for filename, line_number, function_name, text in reversed(tb):

    loc = origin_info.LineLocation(filename=filename, lineno=line_number)
    if loc in source_map:
      # First (innermost) frame that maps back to user code: report its
      # original location and stop; enclosing frames are not summarized.
      origin = source_map[loc]
      result_frames.append(
          FrameInfo(
              filename=origin.loc.filename,
              lineno=origin.loc.lineno,
              function_name=origin.function_name,
              code=origin.source_code_line,
              is_converted=True,
              is_allowlisted=False))
      break

    if filename == converter_filename:
      # Frames of the converted module itself are elided; the frame they
      # called (the last one recorded so far) is flagged as allowlisted.
      if result_frames:
        prev = result_frames[-1]
        # Converted frames terminate the loop above, so prev is never one.
        assert not prev.is_converted
        result_frames[-1] = prev._replace(
            is_converted=False, is_allowlisted=True)
      continue

    result_frames.append(
        FrameInfo(
            filename=filename,
            lineno=line_number,
            function_name=function_name,
            code=text,
            is_converted=False,
            is_allowlisted=False))

  return tuple(result_frames)
123
124
# Exception types that can be rebuilt from a single message string, i.e. where
# `error_type(message)` yields an equivalent error carrying a new message.
KNOWN_STRING_CONSTRUCTOR_ERRORS = (
    AssertionError, AttributeError, NameError, NotImplementedError,
    RuntimeError, StopIteration, TypeError, UnboundLocalError, ValueError,
)
136
137
# KeyError.__str__ formats a single argument with repr(), so a multi-line
# message would be shown with its newlines escaped. This subclass keeps the
# original key in `args` but renders the full message verbatim. Its name is
# overridden so that it reads as a plain KeyError in error output; hopefully
# that won't create too many surprises.
class MultilineMessageKeyError(KeyError):
  """KeyError whose str() is a preformatted, possibly multi-line message."""

  def __init__(self, message, original_key):
    super(MultilineMessageKeyError, self).__init__(original_key)
    self.__message = message

  def __str__(self):
    return self.__message


# Override both names: some formatters read __name__, while the `traceback`
# module formats exception types using __qualname__.
MultilineMessageKeyError.__name__ = KeyError.__name__
MultilineMessageKeyError.__qualname__ = KeyError.__qualname__
151
152
class ErrorMetadataBase(object):
  """Container objects attached to exceptions raised in user code.

  This metadata allows re-raising exceptions that occur in generated code, with
  a custom error message that includes a stack trace relative to user-readable
  code from which the generated code originated.

  Attributes:
    translated_stack: Tuple[FrameInfo, ...], the translated frames in the order
      produced by _stack_trace_inside_mapped_code; get_message prints them in
      reverse.
    cause_message: str, the message text of the underlying error, printed
      (indented) at the end of get_message.
  """

  __slots__ = ('translated_stack', 'cause_message')

  def __init__(self, callsite_tb, cause_metadata, cause_message, source_map,
               converter_filename):
    translated_stack = _stack_trace_inside_mapped_code(
        callsite_tb, source_map, converter_filename)

    if cause_metadata is None:
      self.translated_stack = translated_stack
      self.cause_message = cause_message
    else:
      # Daisy chain the translated stacks: keep the cause's stack and append
      # only the last frame of this level. The slice (rather than [-1])
      # tolerates an empty translated stack instead of raising IndexError.
      self.translated_stack = (
          cause_metadata.translated_stack + translated_stack[-1:])
      self.cause_message = cause_metadata.cause_message

  def get_message(self):
    """Returns the message for the underlying exception.

    Frames are listed in reverse order of `translated_stack`; converted frames
    are suffixed with `*` and allowlisted frames with `**`. The cause message
    follows, indented by four spaces.
    """
    lines = ['in user code:', '']

    for frame_info in reversed(self.translated_stack):
      formatted_line = '    {}:{} {}'.format(frame_info.filename,
                                             frame_info.lineno,
                                             frame_info.function_name)
      if frame_info.is_converted:
        formatted_line += '  *'
      elif frame_info.is_allowlisted:
        formatted_line += '  **'
      lines.append(formatted_line)

      if frame_info.code is None:
        code_snippet = '<source unavailable>'
      else:
        code_snippet = frame_info.code.strip()
      lines.append('        {}'.format(code_snippet))

    lines.append('')
    lines.extend('    ' + line for line in self.cause_message.split('\n'))
    lines.append('')

    return '\n'.join(lines)

  def create_exception(self, source_error):
    """Creates an exception of the same type as source_error, if possible.

    Args:
      source_error: the original exception.

    Returns:
      A new exception whose message is get_message(), or None when the type of
      source_error is not known to accept a single message argument.
    """
    preferred_type = type(source_error)
    if (preferred_type.__init__ is Exception.__init__ or
        preferred_type in KNOWN_STRING_CONSTRUCTOR_ERRORS):
      return preferred_type(self.get_message())
    if preferred_type is KeyError:
      # A plain KeyError would repr() the message and escape its newlines.
      return MultilineMessageKeyError(self.get_message(), self.cause_message)
    return None

  def to_exception(self, source_error):
    """Converts source_error into an exception carrying this metadata.

    Args:
      source_error: the original exception.

    Returns:
      The exception built by create_exception, with implicit context display
      suppressed and this object attached as `ag_error_metadata`; or None when
      create_exception could not build one.
    """
    exc = self.create_exception(source_error)
    if exc is None:
      # Nothing to annotate; let the caller decide how to handle this type.
      return None
    exc.__suppress_context__ = True
    exc.ag_error_metadata = self
    return exc
226