1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Error rewriting logic. 16 17Contains the functions responsible for rewriting tracebacks of errors raised 18in AutoGraph (AG) code to refer to user written code, so that errors only refer 19to the original user code. 20 21When 'user code' is used in comments it refers to the original source code that 22the user wrote and is converting using AutoGraph. 23""" 24 25from __future__ import absolute_import 26from __future__ import division 27from __future__ import print_function 28 29import contextlib 30import logging 31import sys 32import traceback 33 34from tensorflow.python.autograph.pyct import origin_info 35from tensorflow.python.framework import errors_impl 36 37# TODO(mdan): Add a superclass common to all errors. 
38 39 40class GraphConstructionError(Exception): 41 """Error for graph construction errors from AutoGraph generated code.""" 42 43 def __init__(self, original_error, custom_traceback): 44 self.original_error = original_error 45 self.custom_traceback = custom_traceback 46 super(GraphConstructionError, self).__init__() 47 48 def __str__(self): 49 traceback_str = ''.join(traceback.format_list(self.custom_traceback)) 50 return ('Traceback (most recent call last):\n' + traceback_str + '\n' + str( 51 self.original_error) + '\n') 52 53 54class TfRuntimeError(Exception): 55 """Error wrapper for runtime errors raised by AutoGraph generated code.""" 56 57 def __init__(self, op_name, op_message, custom_traceback): 58 self.op_name = op_name 59 self.op_message = op_message 60 self.custom_traceback = custom_traceback 61 super(TfRuntimeError, self).__init__() 62 63 def __str__(self): 64 message = '%s\n\nCaused by op %r, defined at:\n' % (self.op_message, 65 self.op_name) 66 return message + ''.join(traceback.format_list(self.custom_traceback)) 67 68 69def _rewrite_tb(source_map, tb): 70 """Rewrites code references in a traceback. 71 72 Args: 73 source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping 74 locations to their origin 75 tb: List[Tuple[Text, Text, Text, Text]], consistent with 76 traceback.extract_tb. 77 Returns: 78 List[Tuple[Text, Text, Text, Text]], the rewritten traceback 79 """ 80 new_tb = [] 81 for frame in tb: 82 filename, lineno, _, _ = frame 83 loc = origin_info.LineLocation(filename, lineno) 84 origin = source_map.get(loc) 85 if origin is not None: 86 new_tb.append(origin.as_frame()) 87 else: 88 new_tb.append(frame) 89 return new_tb 90 91 92# TODO(mdan): rename to raise_* 93def rewrite_graph_construction_error(source_map): 94 """Rewrites errors raised by non-AG APIs inside AG generated code. 95 96 This is called from the except handler inside an AutoGraph generated function 97 (that is, during exception handling). 
Only rewrites the frames corresponding 98 to the function that this is called from, so each function is responsible 99 to call this to have its own frames rewritten. 100 101 This function always raises an error. 102 103 Args: 104 source_map: Dict[origin_info.Location, origin_info.OriginInfo], the source 105 map belonging to the calling function 106 107 Raises: 108 GraphConstructionError: The rewritten underlying error. 109 Exception: The underlying error, if it could not be rewritten. 110 """ 111 error_info = sys.exc_info() 112 _, original_error, e_traceback = error_info 113 assert original_error is not None 114 try: 115 current_traceback = _cut_traceback_loops(source_map, 116 traceback.extract_tb(e_traceback)) 117 if isinstance(original_error, GraphConstructionError): 118 # TODO(mdan): This is incomplete. 119 # The error might have bubbled through a non-converted function. 120 previous_traceback = original_error.custom_traceback 121 cleaned_traceback = [current_traceback[0]] + previous_traceback 122 else: 123 cleaned_traceback = current_traceback 124 125 cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback) 126 127 if isinstance(original_error, GraphConstructionError): 128 original_error.custom_traceback = cleaned_traceback 129 new_error = original_error 130 else: 131 new_error = GraphConstructionError(original_error, cleaned_traceback) 132 except Exception: 133 logging.exception('Error while rewriting AutoGraph error:') 134 # TODO(mdan): Should reraise here, removing the top frame as well. 135 raise original_error 136 else: 137 raise new_error 138 finally: 139 # Addresses warning https://docs.python.org/2/library/sys.html#sys.exc_info. 140 del e_traceback 141 142 143def _cut_traceback_loops(source_map, original_traceback): 144 """Check for cases where we leave a user method and re-enter it. 145 146 This is done by looking at the function names when the filenames are from any 147 files the user code is in. 
If we find a case where we return to a user method 148 after leaving it then we cut out the frames in between because we assume this 149 means these in between frames are from internal AutoGraph code that shouldn't 150 be included. 151 152 An example of this is: 153 154 File "file1.py", line 57, in my_func 155 ... 156 File "control_flow_ops.py", line 231, in cond 157 ... 158 File "control_flow_ops.py", line 1039, in inner_cond 159 ... 160 File "file1.py", line 68, in my_func 161 ... 162 163 Where we would remove the control_flow_ops.py frames because we re-enter 164 my_func in file1.py. 165 166 The source map keys are (file_path, line_number) so get the set of all user 167 file_paths. 168 169 Args: 170 source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping 171 locations to their origin 172 original_traceback: List[Tuple[Text, Text, Text, Text]], consistent with 173 traceback.extract_tb. 174 175 Returns: 176 List[Tuple[Text, Text, Text, Text]], the traceback with any loops removed. 177 """ 178 all_user_files = set(loc.filename for loc in source_map) 179 cleaned_traceback = [] 180 last_user_frame_index = None 181 last_user_user_file_path = None 182 # TODO(mdan): Simplify this logic. 183 for fi, frame in enumerate(original_traceback): 184 frame_file_path, lineno, _, _ = frame 185 src_map_key = origin_info.LineLocation(frame_file_path, lineno) 186 if frame_file_path in all_user_files: 187 if src_map_key in source_map: 188 if (last_user_frame_index is not None and 189 last_user_user_file_path == frame_file_path): 190 cleaned_traceback = cleaned_traceback[:last_user_frame_index] 191 last_user_frame_index = fi 192 last_user_user_file_path = frame_file_path 193 cleaned_traceback.append(frame) 194 return cleaned_traceback 195 196 197# TODO(mdan): This should be consistent with rewrite_graph_construction_error 198# Both should either raise or return. 
def rewrite_tf_runtime_error(error, source_map):
  """Rewrites TensorFlow runtime errors raised by ops created in AG code.

  The op's definition traceback has loops cut and is then mapped back to user
  code. If anything goes wrong while doing so, the failure is logged and the
  original error is handed back unchanged.

  Args:
    error: tf.OpError
    source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo]

  Returns:
    TfRuntimeError, the rewritten underlying error.
  """
  try:
    failed_op = error.op
    frames = _cut_traceback_loops(source_map, failed_op.traceback)
    frames = _rewrite_tb(source_map, frames)
    return TfRuntimeError(failed_op.name, error.message, frames)
  except Exception:  # pylint: disable=broad-except
    # Best-effort: never let the rewriting itself mask the real error.
    logging.exception('Error while rewriting AutoGraph error:')
    return error


# TODO(znado): Add arg to enable different levels of error rewriting.
@contextlib.contextmanager
def improved_errors(converted_function):
  """Context manager that rewrites runtime errors.

  Any OpError raised inside the managed block is rewritten so that its
  traceback refers to the original code before conversion.

  Use with the output of to_graph, and wrap the execution of respective ops.
  Example:

    converted_my_func = ag.to_graph(my_func)
    ops = converted_my_func(...)

    with ag.improved_errors(converted_my_func):
      sess.run(ops)

  Args:
    converted_function: Callable[..., Any], the output of a to_graph call

  Yields:
    None

  Raises:
    TfRuntimeError: if any OpError originates in the converted code, it will
      be wrapped into a TfRuntimeError
    ValueError: If converted_function is not generated by AutoGraph
  """
  # None is not a dict, so a single isinstance check covers both a missing
  # attribute and an attribute of the wrong type.
  if not isinstance(getattr(converted_function, 'ag_source_map', None), dict):
    raise ValueError(
        'converted_function must be the result of an autograph.to_graph call')
  try:
    yield
  except errors_impl.OpError as op_error:
    raise rewrite_tf_runtime_error(op_error, converted_function.ag_source_map)