• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Error rewriting logic.
16
17Contains the functions responsible for rewriting tracebacks of errors raised
18in AutoGraph (AG) code to refer to user written code, so that errors only refer
19to the original user code.
20
21When 'user code' is used in comments it refers to the original source code that
22the user wrote and is converting using AutoGraph.
23"""
24
25from __future__ import absolute_import
26from __future__ import division
27from __future__ import print_function
28
29import contextlib
30import logging
31import sys
32import traceback
33
34from tensorflow.python.autograph.pyct import origin_info
35from tensorflow.python.framework import errors_impl
36
37# TODO(mdan): Add a superclass common to all errors.
38
39
class GraphConstructionError(Exception):
  """Wraps an error raised while AutoGraph generated code builds a graph.

  Attributes:
    original_error: the wrapped exception object.
    custom_traceback: list of frame entries in the format accepted by
      traceback.format_list; rendered by __str__ in place of a real traceback.
  """

  def __init__(self, original_error, custom_traceback):
    # Nothing is forwarded to Exception, so `args` stays empty and the whole
    # message is produced by __str__.
    super(GraphConstructionError, self).__init__()
    self.original_error = original_error
    self.custom_traceback = custom_traceback

  def __str__(self):
    """Returns the custom traceback followed by the original error text."""
    pieces = ['Traceback (most recent call last):\n']
    pieces.extend(traceback.format_list(self.custom_traceback))
    pieces.append('\n')
    pieces.append(str(self.original_error))
    pieces.append('\n')
    return ''.join(pieces)
52
53
class TfRuntimeError(Exception):
  """Wraps a runtime error raised by an op created in AutoGraph generated code.

  Attributes:
    op_name: name of the op the error is attributed to.
    op_message: message of the underlying error.
    custom_traceback: list of frame entries in the format accepted by
      traceback.format_list, describing where the op was defined.
  """

  def __init__(self, op_name, op_message, custom_traceback):
    # Nothing is forwarded to Exception; the message comes from __str__.
    super(TfRuntimeError, self).__init__()
    self.custom_traceback = custom_traceback
    self.op_message = op_message
    self.op_name = op_name

  def __str__(self):
    """Returns the op message, the op name and the formatted traceback."""
    header = '%s\n\nCaused by op %r, defined at:\n' % (
        self.op_message, self.op_name)
    frame_lines = traceback.format_list(self.custom_traceback)
    return ''.join([header] + frame_lines)
67
68
69def _rewrite_tb(source_map, tb):
70  """Rewrites code references in a traceback.
71
72  Args:
73    source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping
74        locations to their origin
75    tb: List[Tuple[Text, Text, Text, Text]], consistent with
76        traceback.extract_tb.
77  Returns:
78    List[Tuple[Text, Text, Text, Text]], the rewritten traceback
79  """
80  new_tb = []
81  for frame in tb:
82    filename, lineno, _, _ = frame
83    loc = origin_info.LineLocation(filename, lineno)
84    origin = source_map.get(loc)
85    if origin is not None:
86      new_tb.append(origin.as_frame())
87    else:
88      new_tb.append(frame)
89  return new_tb
90
91
92# TODO(mdan): rename to raise_*
def rewrite_graph_construction_error(source_map):
  """Rewrites errors raised by non-AG APIs inside AG generated code.

  This is called from the except handler inside an AutoGraph generated function
  (that is, during exception handling). Only rewrites the frames corresponding
  to the function that this is called from, so each function is responsible
  to call this to have its own frames rewritten.

  If the error being handled is already a GraphConstructionError, it is updated
  in place (its custom traceback is extended with the first frame of the
  current traceback and rewritten) and re-raised; otherwise the error is
  wrapped in a new GraphConstructionError.

  This function always raises an error.

  Args:
    source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], the
        source map belonging to the calling function

  Raises:
    GraphConstructionError: The rewritten underlying error.
    Exception: The underlying error, if it could not be rewritten.
  """
  # Unpack directly instead of keeping the exc_info tuple in a local: the tuple
  # would hold a second reference to the traceback and defeat the `del` in the
  # `finally` clause below.
  _, original_error, e_traceback = sys.exc_info()
  assert original_error is not None, (
      'rewrite_graph_construction_error must be called while an exception is '
      'being handled')
  try:
    current_traceback = _cut_traceback_loops(source_map,
                                             traceback.extract_tb(e_traceback))
    already_wrapped = isinstance(original_error, GraphConstructionError)
    if already_wrapped:
      # TODO(mdan): This is incomplete.
      # The error might have bubbled through a non-converted function.
      previous_traceback = original_error.custom_traceback
      cleaned_traceback = [current_traceback[0]] + previous_traceback
    else:
      cleaned_traceback = current_traceback

    cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback)

    if already_wrapped:
      # Reuse the existing wrapper so nested converted functions do not stack
      # wrappers around each other.
      original_error.custom_traceback = cleaned_traceback
      new_error = original_error
    else:
      new_error = GraphConstructionError(original_error, cleaned_traceback)
  except Exception:  # pylint: disable=broad-except
    # Rewriting is best effort: log the failure and surface the original error.
    logging.exception('Error while rewriting AutoGraph error:')
    # TODO(mdan): Should reraise here, removing the top frame as well.
    raise original_error
  else:
    raise new_error
  finally:
    # Addresses warning https://docs.python.org/2/library/sys.html#sys.exc_info.
    del e_traceback
141
142
143def _cut_traceback_loops(source_map, original_traceback):
144  """Check for cases where we leave a user method and re-enter it.
145
146  This is done by looking at the function names when the filenames are from any
147  files the user code is in.  If we find a case where we return to a user method
148  after leaving it then we cut out the frames in between because we assume this
149  means these in between frames are from internal AutoGraph code that shouldn't
150  be included.
151
152  An example of this is:
153
154   File "file1.py", line 57, in my_func
155     ...
156   File "control_flow_ops.py", line 231, in cond
157     ...
158   File "control_flow_ops.py", line 1039, in inner_cond
159     ...
160   File "file1.py", line 68, in my_func
161     ...
162
163  Where we would remove the control_flow_ops.py frames because we re-enter
164  my_func in file1.py.
165
166  The source map keys are (file_path, line_number) so get the set of all user
167  file_paths.
168
169  Args:
170    source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping
171      locations to their origin
172    original_traceback: List[Tuple[Text, Text, Text, Text]], consistent with
173      traceback.extract_tb.
174
175  Returns:
176    List[Tuple[Text, Text, Text, Text]], the traceback with any loops removed.
177  """
178  all_user_files = set(loc.filename for loc in source_map)
179  cleaned_traceback = []
180  last_user_frame_index = None
181  last_user_user_file_path = None
182  # TODO(mdan): Simplify this logic.
183  for fi, frame in enumerate(original_traceback):
184    frame_file_path, lineno, _, _ = frame
185    src_map_key = origin_info.LineLocation(frame_file_path, lineno)
186    if frame_file_path in all_user_files:
187      if src_map_key in source_map:
188        if (last_user_frame_index is not None and
189            last_user_user_file_path == frame_file_path):
190          cleaned_traceback = cleaned_traceback[:last_user_frame_index]
191      last_user_frame_index = fi
192      last_user_user_file_path = frame_file_path
193    cleaned_traceback.append(frame)
194  return cleaned_traceback
195
196
197# TODO(mdan): This should be consistent with rewrite_graph_construction_error
198# Both should either raise or return.
def rewrite_tf_runtime_error(error, source_map):
  """Rewrites TensorFlow runtime errors raised by ops created in AG code.

  The traceback recorded on the failing op (`error.op.traceback`) has its
  loops cut and its frames mapped back through `source_map`; the result is
  packaged together with the op name and the error message.

  Args:
    error: tf.OpError
    source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo]

  Returns:
    TfRuntimeError, the rewritten underlying error. If anything goes wrong
    while rewriting, the failure is logged and `error` itself is returned.
  """
  try:
    user_frames = _rewrite_tb(
        source_map, _cut_traceback_loops(source_map, error.op.traceback))
    return TfRuntimeError(error.op.name, error.message, user_frames)
  except Exception:  # pylint: disable=broad-except
    # Best effort: never let a rewriting failure hide the real error.
    logging.exception('Error while rewriting AutoGraph error:')
    return error
220
221
222# TODO(znado): Add arg to enable different levels of error rewriting.
@contextlib.contextmanager
def improved_errors(converted_function):
  """Context manager that rewrites runtime errors.

  Any OpError escaping the managed block is passed through
  rewrite_tf_runtime_error, so that its traceback is relative to the original
  code before conversion.

  Use with the output of to_graph, and wrap the execution of respective ops.
  Example:

    converted_my_func = ag.to_graph(my_func)
    ops = converted_my_func(...)

    with ag.improved_errors(converted_my_func):
      sess.run(ops)

  Args:
    converted_function: Callable[..., Any], the output of a to_graph call; it
        must carry a dict in its `ag_source_map` attribute

  Yields:
    None

  Raises:
    TfRuntimeError: if any OpError originates in the converted code, it will
        be wrapped into a TfRuntimeError (the original error is raised
        instead if the rewriting itself fails)
    ValueError: If converted_function is not generated by AutoGraph
  """
  # A missing attribute reads as None, which is not a dict either, so one
  # isinstance check covers both the "absent" and the "wrong type" cases.
  if not isinstance(getattr(converted_function, 'ag_source_map', None), dict):
    raise ValueError(
        'converted_function must be the result of an autograph.to_graph call')
  try:
    yield
  except errors_impl.OpError as op_error:
    raise rewrite_tf_runtime_error(op_error, converted_function.ag_source_map)
258