1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Code transformation exceptions.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22 23from tensorflow.python.autograph.pyct import origin_info 24 25 26class FrameInfo( 27 collections.namedtuple('FrameInfo', 28 ('filename', 'lineno', 'function_name', 'code', 29 'is_converted', 'is_allowlisted'))): 30 31 __slots__ = () 32 33 34def _stack_trace_inside_mapped_code(tb, source_map, converter_filename): 35 """Summarizes inner traceback frames up to the call to a given function. 36 37 This functions locates the innermost (i.e. most recent) frame that corresponds 38 to code that can be mapped by source_map originated from, and returns a 39 translated stack trace ending at that frame. If no such frame is found, the 40 entire stack trace is summarized. 41 42 For example, the following code: 43 44 def f(): 45 for i in tf.range(1): 46 z = y + i # z only defined here 47 48 Would generate this traceback: 49 50 <converted code> 51 ag__.for_stmt(...) 52 <for_stmt> 53 return _known_len_tf_for_stmt(iter_, extra_test, body, init_state) 54 <_known_len_tf_for_stmt> 55 _disallow_undefs_into_loop(*init_state) 56 <_disallow_undefs_into_loop> 57 raise ... 58 59 Which is then processed into: 60 61 <f> 62 for i in tf.range(1): 63 <for_stmt> 64 return _known_len_tf_for_stmt(iter_, extra_test, body, init_state) 65 <_known_len_tf_for_stmt> 66 _disallow_undefs_into_loop(*init_state) 67 <_disallow_undefs_into_loop> 68 raise ... 69 70 Args: 71 tb: traceback.FrameSummary, The traceback corresponding to an error. 72 Typically, the output of traceback.Summary.extract(capture_locals=True). 73 source_map: Dict[LineLocation, OriginInfo], a source map as created by 74 origin_info.create_source_map. 75 converter_filename: str, the file path of the converted module. Call frames 76 corresponding to this module are elided and their preceding frames are 77 marked as allowlisted. Note that frames enclosing converted code are 78 dropped using a different mechanism. 79 80 Returns: 81 List[FrameInfo] 82 """ 83 result_frames = [] 84 for filename, line_number, function_name, text in reversed(tb): 85 86 loc = origin_info.LineLocation(filename=filename, lineno=line_number) 87 if loc in source_map: 88 origin = source_map[loc] 89 fi = FrameInfo( 90 filename=origin.loc.filename, 91 lineno=origin.loc.lineno, 92 function_name=origin.function_name, 93 code=origin.source_code_line, 94 is_converted=True, 95 is_allowlisted=False) 96 result_frames.append(fi) 97 break 98 99 if filename == converter_filename: 100 if result_frames: 101 prev = result_frames[-1] 102 assert not prev.is_converted # See the if above. 103 fi = FrameInfo( 104 filename=prev.filename, 105 lineno=prev.lineno, 106 function_name=prev.function_name, 107 code=prev.code, 108 is_converted=False, 109 is_allowlisted=True) 110 result_frames[-1] = fi 111 continue 112 113 fi = FrameInfo( 114 filename=filename, 115 lineno=line_number, 116 function_name=function_name, 117 code=text, 118 is_converted=False, 119 is_allowlisted=False) 120 result_frames.append(fi) 121 122 return tuple(result_frames) 123 124 125KNOWN_STRING_CONSTRUCTOR_ERRORS = ( 126 AssertionError, 127 AttributeError, 128 NameError, 129 NotImplementedError, 130 RuntimeError, 131 StopIteration, 132 TypeError, 133 UnboundLocalError, 134 ValueError, 135) 136 137 138# KeyError escapes newlines in strings. We create a special subclass 139# that doesn't do that. Overriding the name for display purposes; hopefully 140# that won't create too many surprises. 141class MultilineMessageKeyError(KeyError): 142 143 def __init__(self, message, original_key): 144 super(MultilineMessageKeyError, self).__init__(original_key) 145 self.__message = message 146 147 def __str__(self): 148 return self.__message 149 150MultilineMessageKeyError.__name__ = KeyError.__name__ 151 152 153class ErrorMetadataBase(object): 154 """Container objects attached to exceptions raised in user code. 155 156 This metadata allows re-raising exceptions that occur in generated code, with 157 a custom error message that includes a stack trace relative to user-readable 158 code from which the generated code originated. 159 """ 160 161 __slots__ = ('translated_stack', 'cause_message') 162 163 def __init__(self, callsite_tb, cause_metadata, cause_message, source_map, 164 converter_filename): 165 translated_stack = _stack_trace_inside_mapped_code( 166 callsite_tb, source_map, converter_filename) 167 168 if cause_metadata is None: 169 self.translated_stack = translated_stack 170 self.cause_message = cause_message 171 else: 172 # Daisy chain the translated stacks. 173 self.translated_stack = ( 174 cause_metadata.translated_stack + (translated_stack[-1],)) 175 self.cause_message = cause_metadata.cause_message 176 177 def get_message(self): 178 """Returns the message for the underlying exception.""" 179 lines = [] 180 181 lines.append('in user code:') 182 lines.append('') 183 184 for frame_info in reversed(self.translated_stack): 185 formatted_line = ' {}:{} {}'.format(frame_info.filename, 186 frame_info.lineno, 187 frame_info.function_name) 188 if frame_info.is_converted: 189 formatted_line += ' *' 190 elif frame_info.is_allowlisted: 191 formatted_line += ' **' 192 lines.append(formatted_line) 193 194 if frame_info.code is None: 195 code_snippet = '<source unavailable>' 196 else: 197 code_snippet = frame_info.code.strip() 198 lines.append(' {}'.format(code_snippet)) 199 200 lines.append('') 201 202 message_lines = self.cause_message.split('\n') 203 for i in range(len(message_lines)): 204 message_lines[i] = ' ' + message_lines[i] 205 lines.extend(message_lines) 206 207 lines.append('') 208 209 return '\n'.join(lines) 210 211 def create_exception(self, source_error): 212 preferred_type = type(source_error) 213 if preferred_type.__init__ is Exception.__init__: 214 return preferred_type(self.get_message()) 215 if preferred_type in KNOWN_STRING_CONSTRUCTOR_ERRORS: 216 return preferred_type(self.get_message()) 217 elif preferred_type is KeyError: 218 return MultilineMessageKeyError(self.get_message(), self.cause_message) 219 return None 220 221 def to_exception(self, source_error): 222 exc = self.create_exception(source_error) 223 exc.__suppress_context__ = True 224 exc.ag_error_metadata = self 225 return exc 226