1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Code transformation exceptions.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22 23from tensorflow.python.autograph.pyct import origin_info 24from tensorflow.python.util import traceback_utils 25 26 27class FrameInfo( 28 collections.namedtuple('FrameInfo', 29 ('filename', 'lineno', 'function_name', 'code', 30 'is_converted', 'is_allowlisted'))): 31 32 __slots__ = () 33 34 35def _stack_trace_inside_mapped_code(tb, source_map, converter_filename): 36 """Summarizes inner traceback frames up to the call to a given function. 37 38 This functions locates the innermost (i.e. most recent) frame that corresponds 39 to code that can be mapped by source_map originated from, and returns a 40 translated stack trace ending at that frame. If no such frame is found, the 41 entire stack trace is summarized. 42 43 For example, the following code: 44 45 def f(): 46 for i in tf.range(1): 47 z = y + i # z only defined here 48 49 Would generate this traceback: 50 51 <converted code> 52 ag__.for_stmt(...) 53 <for_stmt> 54 return _known_len_tf_for_stmt(iter_, extra_test, body, init_state) 55 <_known_len_tf_for_stmt> 56 _disallow_undefs_into_loop(*init_state) 57 <_disallow_undefs_into_loop> 58 raise ... 59 60 Which is then processed into: 61 62 <f> 63 for i in tf.range(1): 64 <for_stmt> 65 return _known_len_tf_for_stmt(iter_, extra_test, body, init_state) 66 <_known_len_tf_for_stmt> 67 _disallow_undefs_into_loop(*init_state) 68 <_disallow_undefs_into_loop> 69 raise ... 70 71 Args: 72 tb: traceback.FrameSummary, The traceback corresponding to an error. 73 Typically, the output of traceback.Summary.extract(capture_locals=True). 74 source_map: Dict[LineLocation, OriginInfo], a source map as created by 75 origin_info.create_source_map. 76 converter_filename: str, the file path of the converted module. Call frames 77 corresponding to this module are elided and their preceding frames are 78 marked as allowlisted. Note that frames enclosing converted code are 79 dropped using a different mechanism. 80 81 Returns: 82 List[FrameInfo] 83 """ 84 result_frames = [] 85 for filename, line_number, function_name, text in reversed(tb): 86 87 loc = origin_info.LineLocation(filename=filename, lineno=line_number) 88 if loc in source_map: 89 origin = source_map[loc] 90 fi = FrameInfo( 91 filename=origin.loc.filename, 92 lineno=origin.loc.lineno, 93 function_name=origin.function_name, 94 code=origin.source_code_line, 95 is_converted=True, 96 is_allowlisted=False) 97 result_frames.append(fi) 98 break 99 100 if filename == converter_filename: 101 if result_frames: 102 prev = result_frames[-1] 103 assert not prev.is_converted # See the if above. 104 fi = FrameInfo( 105 filename=prev.filename, 106 lineno=prev.lineno, 107 function_name=prev.function_name, 108 code=prev.code, 109 is_converted=False, 110 is_allowlisted=True) 111 result_frames[-1] = fi 112 continue 113 114 fi = FrameInfo( 115 filename=filename, 116 lineno=line_number, 117 function_name=function_name, 118 code=text, 119 is_converted=False, 120 is_allowlisted=False) 121 result_frames.append(fi) 122 123 return tuple(result_frames) 124 125 126KNOWN_STRING_CONSTRUCTOR_ERRORS = ( 127 AssertionError, 128 AttributeError, 129 NameError, 130 NotImplementedError, 131 RuntimeError, 132 StopIteration, 133 TypeError, 134 UnboundLocalError, 135 ValueError, 136) 137 138 139# KeyError escapes newlines in strings. We create a special subclass 140# that doesn't do that. Overriding the name for display purposes; hopefully 141# that won't create too many surprises. 142class MultilineMessageKeyError(KeyError): 143 144 def __init__(self, message, original_key): 145 super(MultilineMessageKeyError, self).__init__(original_key) 146 self.__message = message 147 148 def __str__(self): 149 return self.__message 150 151MultilineMessageKeyError.__name__ = KeyError.__name__ 152 153 154class ErrorMetadataBase(object): 155 """Container objects attached to exceptions raised in user code. 156 157 This metadata allows re-raising exceptions that occur in generated code, with 158 a custom error message that includes a stack trace relative to user-readable 159 code from which the generated code originated. 160 """ 161 162 __slots__ = ('translated_stack', 'cause_message') 163 164 def __init__(self, callsite_tb, cause_metadata, cause_message, source_map, 165 converter_filename): 166 translated_stack = _stack_trace_inside_mapped_code( 167 callsite_tb, source_map, converter_filename) 168 169 if cause_metadata is None: 170 self.translated_stack = translated_stack 171 self.cause_message = cause_message 172 else: 173 # Daisy chain the translated stacks. 174 self.translated_stack = ( 175 cause_metadata.translated_stack + (translated_stack[-1],)) 176 self.cause_message = cause_metadata.cause_message 177 178 def get_message(self): 179 """Returns the message for the underlying exception.""" 180 lines = [] 181 182 lines.append('in user code:') 183 lines.append('') 184 185 for frame_info in reversed(self.translated_stack): 186 if (traceback_utils.is_traceback_filtering_enabled() and 187 not traceback_utils.include_frame(frame_info.filename)): 188 continue 189 190 formatted_line = ' {}:{} {}'.format(frame_info.filename, 191 frame_info.lineno, 192 frame_info.function_name) 193 if frame_info.is_converted: 194 formatted_line += ' *' 195 elif frame_info.is_allowlisted: 196 formatted_line += ' **' 197 lines.append(formatted_line) 198 199 if frame_info.code is None: 200 code_snippet = '<source unavailable>' 201 else: 202 code_snippet = frame_info.code.strip() 203 lines.append(' {}'.format(code_snippet)) 204 205 lines.append('') 206 207 message_lines = self.cause_message.split('\n') 208 for i in range(len(message_lines)): 209 message_lines[i] = ' ' + message_lines[i] 210 lines.extend(message_lines) 211 212 lines.append('') 213 214 return '\n'.join(lines) 215 216 def create_exception(self, source_error): 217 preferred_type = type(source_error) 218 if preferred_type.__init__ is Exception.__init__: 219 return preferred_type(self.get_message()) 220 if preferred_type in KNOWN_STRING_CONSTRUCTOR_ERRORS: 221 return preferred_type(self.get_message()) 222 elif preferred_type is KeyError: 223 return MultilineMessageKeyError(self.get_message(), self.cause_message) 224 return None 225 226 def to_exception(self, source_error): 227 exc = self.create_exception(source_error) 228 exc.__suppress_context__ = True 229 exc.ag_error_metadata = self 230 return exc 231