• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Code transformation exceptions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23from tensorflow.python.autograph.pyct import origin_info
24from tensorflow.python.util import traceback_utils
25
26
class FrameInfo(
    collections.namedtuple(
        'FrameInfo',
        'filename lineno function_name code is_converted is_allowlisted')):
  """One entry of a translated stack trace.

  Fields:
    filename: path of the file the frame refers to.
    lineno: line number within filename.
    function_name: name of the function executing in the frame.
    code: the source line, or None when unavailable.
    is_converted: whether the frame was mapped back from generated code.
    is_allowlisted: whether the frame was marked allowlisted.
  """

  __slots__ = ()
33
34
def _stack_trace_inside_mapped_code(tb, source_map, converter_filename):
  """Summarizes inner traceback frames up to the call to a given function.

  This function walks the traceback from the innermost (i.e. most recent) frame
  outwards and stops at the first frame whose location can be mapped back to
  original source code through source_map. It returns a translated stack trace
  ending at that frame. If no such frame is found, the entire stack trace is
  summarized.

  For example, the following code:

    def f():
      for i in tf.range(1):
        z = y + i  # z only defined here

  Would generate this traceback:

    <converted code>
        ag__.for_stmt(...)
    <for_stmt>
        return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
    <_known_len_tf_for_stmt>
        _disallow_undefs_into_loop(*init_state)
    <_disallow_undefs_into_loop>
        raise ...

  Which is then processed into:

    <f>
        for i in tf.range(1):
    <for_stmt>
        return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
    <_known_len_tf_for_stmt>
        _disallow_undefs_into_loop(*init_state)
    <_disallow_undefs_into_loop>
        raise ...

  Args:
    tb: Sequence[traceback.FrameSummary], the traceback corresponding to an
      error, e.g. a traceback.StackSummary as produced by traceback.extract_tb.
      It is iterated in reverse, so it is expected to list the oldest call
      first, as the traceback module does. Each entry must unpack into
      (filename, lineno, function_name, line).
    source_map: Dict[LineLocation, OriginInfo], a source map as created by
      origin_info.create_source_map.
    converter_filename: str, the file path of the converted module. Call frames
      corresponding to this module are elided, and the frame each of them
      called into is marked as allowlisted. Note that frames enclosing
      converted code are dropped using a different mechanism.

  Returns:
    Tuple[FrameInfo, ...], ordered innermost frame first. If a frame mapped
    through source_map was found, it is the last element.
  """
  result_frames = []
  for filename, line_number, function_name, text in reversed(tb):

    loc = origin_info.LineLocation(filename=filename, lineno=line_number)
    if loc in source_map:
      # Innermost frame that maps back to user source: report it at its
      # original location and stop; frames further out are not included.
      origin = source_map[loc]
      result_frames.append(
          FrameInfo(
              filename=origin.loc.filename,
              lineno=origin.loc.lineno,
              function_name=origin.function_name,
              code=origin.source_code_line,
              is_converted=True,
              is_allowlisted=False))
      break

    if filename == converter_filename:
      # Elide the converter frame itself; flag the frame it called into (the
      # one processed just before, since we iterate innermost first).
      if result_frames:
        prev = result_frames[-1]
        assert not prev.is_converted  # Converted frames end the loop above.
        result_frames[-1] = prev._replace(
            is_converted=False, is_allowlisted=True)
      continue

    result_frames.append(
        FrameInfo(
            filename=filename,
            lineno=line_number,
            function_name=function_name,
            code=text,
            is_converted=False,
            is_allowlisted=False))

  return tuple(result_frames)
124
125
# Exception types that ErrorMetadataBase.create_exception rebuilds by calling
# the type with the translated message string as its only argument.
KNOWN_STRING_CONSTRUCTOR_ERRORS = (
    AssertionError, AttributeError, NameError, NotImplementedError,
    RuntimeError, StopIteration, TypeError, UnboundLocalError, ValueError,
)
137
138
# KeyError escapes newlines in strings. We create a special subclass
# that doesn't do that. Overriding the name for display purposes; hopefully
# that won't create too many surprises.
class MultilineMessageKeyError(KeyError):
  """A KeyError whose str() is a preformatted, possibly multi-line message.

  As with a regular KeyError, `args` holds only the original key; the message
  passed to the constructor is returned verbatim by __str__.
  """

  def __init__(self, message, original_key):
    super(MultilineMessageKeyError, self).__init__(original_key)
    self.__message = message

  def __str__(self):
    return self.__message

  def __reduce__(self):
    # BaseException.__reduce__ would recreate the instance as cls(*self.args),
    # omitting `message`, so copy.copy() and pickle fail with a TypeError.
    # Rebuild with both constructor arguments and restore instance attributes.
    return (self.__class__, (self.__message,) + self.args, self.__dict__)


MultilineMessageKeyError.__name__ = KeyError.__name__
152
153
class ErrorMetadataBase(object):
  """Container objects attached to exceptions raised in user code.

  This metadata allows re-raising exceptions that occur in generated code, with
  a custom error message that includes a stack trace relative to user-readable
  code from which the generated code originated.

  Attributes:
    translated_stack: Tuple[FrameInfo, ...], translated frames ordered
      innermost first, as returned by _stack_trace_inside_mapped_code.
    cause_message: str, the message of the underlying error.
  """

  __slots__ = ('translated_stack', 'cause_message')

  def __init__(self, callsite_tb, cause_metadata, cause_message, source_map,
               converter_filename):
    """Translates callsite_tb and optionally chains it onto cause_metadata.

    Args:
      callsite_tb: traceback frames at the call site; passed to
        _stack_trace_inside_mapped_code.
      cause_metadata: Optional[ErrorMetadataBase], metadata of the error being
        re-raised, or None.
      cause_message: str, the underlying error message. Only used when
        cause_metadata is None; otherwise cause_metadata's message is kept.
      source_map: passed to _stack_trace_inside_mapped_code.
      converter_filename: passed to _stack_trace_inside_mapped_code.
    """
    translated_stack = _stack_trace_inside_mapped_code(
        callsite_tb, source_map, converter_filename)

    if cause_metadata is None:
      self.translated_stack = translated_stack
      self.cause_message = cause_message
    else:
      # Daisy chain the translated stacks: keep the inner error's frames and
      # append only the outermost frame translated for this call site. The
      # slice (instead of [-1]) tolerates an empty translated stack.
      self.translated_stack = (
          cause_metadata.translated_stack + translated_stack[-1:])
      self.cause_message = cause_metadata.cause_message

  def get_message(self):
    """Returns the message for the underlying exception.

    Frames are listed outermost first, each as a `path:lineno function` header
    followed by the stripped source line (or `<source unavailable>`). Converted
    frames are marked with `*` and allowlisted frames with `**`. When traceback
    filtering is enabled, frames rejected by traceback_utils.include_frame are
    omitted. The cause message follows, indented by four spaces.
    """
    lines = ['in user code:', '']

    for frame_info in reversed(self.translated_stack):
      if (traceback_utils.is_traceback_filtering_enabled() and
          not traceback_utils.include_frame(frame_info.filename)):
        continue

      formatted_line = '    {}:{} {}'.format(frame_info.filename,
                                             frame_info.lineno,
                                             frame_info.function_name)
      if frame_info.is_converted:
        formatted_line += '  *'
      elif frame_info.is_allowlisted:
        formatted_line += '  **'
      lines.append(formatted_line)

      if frame_info.code is None:
        code_snippet = '<source unavailable>'
      else:
        code_snippet = frame_info.code.strip()
      lines.append('        {}'.format(code_snippet))

    lines.append('')
    lines.extend('    ' + line for line in self.cause_message.split('\n'))
    lines.append('')

    return '\n'.join(lines)

  def create_exception(self, source_error):
    """Creates a new exception of the same kind as source_error.

    Args:
      source_error: the exception being translated.

    Returns:
      A new exception whose message is get_message(), or None if the type of
      source_error is not known to accept a single message argument.
    """
    preferred_type = type(source_error)
    if (preferred_type.__init__ is Exception.__init__ or
        preferred_type in KNOWN_STRING_CONSTRUCTOR_ERRORS):
      return preferred_type(self.get_message())
    if preferred_type is KeyError:
      # KeyError escapes newlines in its str(); use the multi-line subclass.
      return MultilineMessageKeyError(self.get_message(), self.cause_message)
    return None

  def to_exception(self, source_error):
    """Returns the exception to raise in place of source_error.

    The translated exception has `__suppress_context__` set and carries this
    metadata as `ag_error_metadata`. If create_exception cannot build one (it
    returns None), source_error is returned unchanged, so callers re-raise the
    original error instead of failing with an AttributeError on None.
    """
    exc = self.create_exception(source_error)
    if exc is None:
      return source_error
    exc.__suppress_context__ = True
    exc.ag_error_metadata = self
    return exc
231