1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Function for interpolating formatted errors from the TensorFlow runtime. 16 17Exposes the function `interpolate` to interpolate messages with tags of the form 18{{type name}}. 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import collections 26import itertools 27import os 28import re 29 30import six 31 32from tensorflow.core.protobuf import graph_debug_info_pb2 33 34_NAME_REGEX = r"[A-Za-z0-9_.][A-Za-z0-9_.\-/]*?" 35_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX) 36_INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX) 37_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL) 38 39_ParseTag = collections.namedtuple("_ParseTag", ["type", "name"]) 40 41 42# Remove the last three path components from this module's file (i.e. 43# python/framework/error_interpolation.py) so that we have an absolute path 44# prefix to the root of the installation. 45_FRAMEWORK_COMMON_PREFIX = os.path.dirname( 46 os.path.dirname(os.path.dirname(__file__))) 47 48# Sub-directories under the common prefix that are considered part of the 49# framework. 
_FRAMEWORK_PATH_PREFIXES = [
    os.path.join(_FRAMEWORK_COMMON_PREFIX, "python") + os.sep,
    os.path.join(_FRAMEWORK_COMMON_PREFIX, "contrib") + os.sep,
]

# Regexes of filenames that should be considered internal to the TensorFlow
# framework even if they do not start with a framework path prefix.
_FRAMEWORK_FILENAME_PATTERNS = [
    re.compile(r"<embedded"),
]

# Regexes of filenames that should be considered external to TensorFlow
# regardless of framework prefix match.
_EXTERNAL_FILENAME_PATTERNS = [
    # Explicitly treat test frames as not part of the framework.
    re.compile(r"_test\.py$"),
]


def parse_message(message):
  """Parses the message.

  Splits the message into separators and tags. Tags are named tuples
  representing the string {{type name}} and they are separated by
  separators. For example, in "123{{node Foo}}456{{node Bar}}789", there are
  two tags and three separators. The separators are the numeric characters.

  Args:
    message: String to parse

  Returns:
    (list of separator strings, list of _ParseTags).

    For example, if message is "123{{node Foo}}456" then this function
    returns (["123", "456"], [_ParseTag("node", "Foo")])
  """
  seps = []
  tags = []
  pos = 0
  while pos < len(message):
    match = re.match(_INTERPOLATION_PATTERN, message[pos:])
    if match:
      # Group 1 is the separator before the tag; groups 3 and 4 are the tag's
      # type and name (group 2 is the whole "{{type name}}" tag).
      seps.append(match.group(1))
      tags.append(_ParseTag(match.group(3), match.group(4)))
      pos += match.end()
    else:
      break
  # The trailing text after the last tag is the final separator.
  seps.append(message[pos:])
  return seps, tags


def _compute_device_summary_from_list(name, device_assignment_list, prefix=""):
  """Return a summary of an op's device function stack.

  Args:
    name: The name of the op.
    device_assignment_list: The op._device_assignments list.
    prefix: An optional string prefix used before each line of the multi-
      line string returned by this function.

  Returns:
    A multi-line string similar to:
        Device assignments active during op 'foo' creation:
          with tf.device(/cpu:0): <test_1.py:27>
          with tf.device(some_func<foo.py, 123>): <test_2.py:38>
    The first line will have no padding to its left by default.  Subsequent
    lines will have two spaces of left-padding.  Use the prefix argument
    to increase indentation.
  """
  if not device_assignment_list:
    message = "No device assignments were active during op '%s' creation."
    message %= name
    return prefix + message

  str_list = []
  str_list.append(
      "%sDevice assignments active during op '%s' creation:" % (prefix, name))

  for traceable_obj in device_assignment_list:
    location_summary = "<{file}:{line}>".format(
        file=traceable_obj.filename, line=traceable_obj.lineno)
    subs = {
        "prefix": prefix,
        "indent": "  ",
        "dev_name": traceable_obj.obj,
        "loc": location_summary,
    }
    str_list.append(
        "{prefix}{indent}with tf.device({dev_name}): {loc}".format(**subs))

  return "\n".join(str_list)


def _compute_device_assignment_summary_from_op(op, prefix=""):
  """Return a device assignment summary for `op` (see the _from_list helper)."""
  # pylint: disable=protected-access
  return _compute_device_summary_from_list(op.name, op._device_assignments,
                                           prefix)
  # pylint: enable=protected-access


def _compute_colocation_summary_from_dict(name, colocation_dict, prefix=""):
  """Return a summary of an op's colocation stack.

  Args:
    name: The op name.
    colocation_dict: The op._colocation_dict.
    prefix: An optional string prefix used before each line of the multi-
      line string returned by this function.

  Returns:
    A multi-line string similar to:
        Node-device colocations active during op creation:
          with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27>
          with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>
    The first line will have no padding to its left by default.  Subsequent
    lines will have two spaces of left-padding.  Use the prefix argument
    to increase indentation.
  """
  if not colocation_dict:
    message = "No node-device colocations were active during op '%s' creation."
    message %= name
    return prefix + message

  str_list = []
  str_list.append("%sNode-device colocations active during op '%s' creation:" %
                  (prefix, name))

  for coloc_name, location in colocation_dict.items():
    location_summary = "<{file}:{line}>".format(
        file=location.filename, line=location.lineno)
    subs = {
        "prefix": prefix,
        "indent": "  ",
        "name": coloc_name,
        "loc": location_summary,
    }
    str_list.append(
        "{prefix}{indent}with tf.colocate_with({name}): {loc}".format(**subs))

  return "\n".join(str_list)


def _compute_colocation_summary_from_op(op, prefix=""):
  """Fetch colocation file, line, and nesting and return a summary string."""
  # pylint: disable=protected-access
  return _compute_colocation_summary_from_dict(op.name, op._colocation_dict,
                                               prefix)
  # pylint: enable=protected-access


def _is_framework_filename(filename):
  """Returns whether a filename should be considered a part of the framework.

  A file is part of the framework if it does not match a pattern in
  _EXTERNAL_FILENAME_PATTERNS and it either matches a pattern in
  _FRAMEWORK_FILENAME_PATTERNS or starts with a _FRAMEWORK_PATH_PREFIXES prefix.

  Args:
    filename: A filename string.

  Returns:
    Whether the filename should be considered to be internal to the
    TensorFlow framework for the purposes of reporting errors.
  """
  # External patterns take precedence over framework patterns/prefixes.
  if any(pattern.search(filename) for pattern in _EXTERNAL_FILENAME_PATTERNS):
    return False
  if any(pattern.search(filename) for pattern in _FRAMEWORK_FILENAME_PATTERNS):
    return True
  return any(
      filename.startswith(prefix) for prefix in _FRAMEWORK_PATH_PREFIXES)


def _find_index_of_defining_frame(traceback):
  """Return index in op.traceback with first 'useful' frame.

  This method reads through the stack stored in op.traceback looking for the
  innermost frame which (hopefully) belongs to the caller.  It accomplishes this
  by rejecting frames deemed to be part of the TensorFlow framework (by
  pattern matching the filename).

  Args:
    traceback: A list of traceback frames (as from Operation.traceback).

  Returns:
    Integer index into op.traceback where the first non-TF file was found
    (innermost to outermost), or 0 (for the outermost stack frame) if all files
    came from TensorFlow.
  """
  # Index 0 of traceback is the outermost frame.
  size = len(traceback)
  filenames = [frame.filename for frame in traceback]
  # We process the filenames from the innermost frame to outermost.
  for idx, filename in enumerate(reversed(filenames)):
    if not _is_framework_filename(filename):
      # Consider this to be the defining frame.
      return size - idx - 1
  return 0


def _get_defining_frame(traceback):
  """Find and return stack frame where op was defined."""
  frame_index = _find_index_of_defining_frame(traceback)
  return traceback[frame_index]


def _compute_useful_frames(traceback, num):
  """Return a list of frames, which form a 'useful' stack.

  Starting from the defining frame to the outermost one, this method computes
  the contiguous portion of the 'useful' stack trace and returns the selected
  frames.

  Args:
    traceback: A list of traceback frames (as from Operation.traceback).
    num: total number of frames to return.

  Returns:
    A list of frames.
  """
  defining_frame_index = _find_index_of_defining_frame(traceback)
  # The stack trace is collected from two lines before the defining frame in the
  # model file to the outermost with `num` frames at most. These two extra lines
  # are included from the TensorFlow library to give the context which node is
  # defined.
  innermost_excluded = min(defining_frame_index + 2 + 1, len(traceback))
  outermost_included = max(innermost_excluded - num, 0)
  return traceback[outermost_included:innermost_excluded]


def create_graph_debug_info_def(func_named_operations):
  """Construct and returns a `GraphDebugInfo` protocol buffer.

  Args:
    func_named_operations: An iterable of (func_name, op.Operation) tuples
      where the Operation instances have a _traceback members. The func_name
      should be the empty string for operations in the top-level Graph.

  Returns:
    GraphDebugInfo protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # Creates an empty GraphDebugInfoDef proto.
  graph_debug_info_def = graph_debug_info_pb2.GraphDebugInfo()

  # Gets the file names and line numbers for the exported node names. Also
  # collects the unique file names.
  all_file_names = set()
  node_to_trace = {}
  for func_name, op in func_named_operations:
    try:
      op_traceback = op.traceback
    except AttributeError:
      # Some ops synthesized on as part of function or control flow definition
      # do not have tracebacks.
      continue

    # Gets the stack trace of the operation and then the file location.
    node_name = op.name + "@" + func_name
    node_to_trace[node_name] = _compute_useful_frames(op_traceback, 10)
    for frame in node_to_trace[node_name]:
      all_file_names.add(frame.filename)

  # Sets the `files` field in the GraphDebugInfo proto. Sort the names so the
  # proto is deterministic: set iteration order is arbitrary and would
  # otherwise make identical graphs produce different (but equivalent) protos.
  graph_debug_info_def.files.extend(sorted(all_file_names))

  # Builds a mapping between file names and index of the `files` field, so we
  # only store the indexes for the nodes in the GraphDebugInfo.
  file_to_index = {
      file_name: index
      for index, file_name in enumerate(graph_debug_info_def.files)
  }

  # Creates the FileLineCol proto for each node and sets the value in the
  # GraphDebugInfo proto. We only store the file name index for each node to
  # save the storage space.
  for node_name, frames in node_to_trace.items():
    trace_def = graph_debug_info_def.traces[node_name]
    # Frames are stored outermost-first in the proto, hence reversed().
    for frame in reversed(frames):
      trace_def.file_line_cols.add(
          file_index=file_to_index[frame.filename], line=frame.lineno)

  return graph_debug_info_def


def _compute_field_dict(op, strip_file_prefix=""):
  """Return a dictionary mapping interpolation tokens to values.

  Args:
    op: op.Operation object having a _traceback member.
    strip_file_prefix: The common path in the stacktrace. We remove the prefix
      from the file names.

  Returns:
    A dictionary mapping string tokens to string values.  The keys are shown
    below along with example values.
    {
      "file": "tool_utils.py",
      "line": "124",
      "defined_at": " (defined at tool_utils.py:124)",
      "colocations":
          '''Node-device colocations active during op creation:
               with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27>
               with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>'''
      "devices":
          '''Device assignments active during op 'foo' creation:
               with tf.device(/cpu:0): <test_1.py:27>
               with tf.device(some_func<foo.py, 123>): <test_2.py:38>'''
      "devs_and_colocs": A concatenation of colocations and devices, e.g.
          '''Node-device colocations active during op creation:
               with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27>
               with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>
             Device assignments active during op 'foo' creation:
               with tf.device(/cpu:0): <test_1.py:27>
               with tf.device(some_func<foo.py, 123>): <test_2.py:38>'''
    }
  """
  colocation_summary = _compute_colocation_summary_from_op(op)
  device_summary = _compute_device_assignment_summary_from_op(op)
  combined_summary = "\n".join([colocation_summary, device_summary])

  # Optional traceback info.
  try:
    traceback = op.traceback
  except AttributeError:
    # Some ops synthesized on as part of function or control flow definition
    # do not have tracebacks.
    filename = "<unknown>"
    lineno = 0
    defined_at = " (defined at <unknown>)"
  else:
    frame = _get_defining_frame(traceback)
    filename = frame.filename
    if filename.startswith(strip_file_prefix):
      filename = filename[len(strip_file_prefix):]
    lineno = frame.lineno
    defined_at = " (defined at %s:%d)" % (filename, lineno)

  field_dict = {
      "colocations": colocation_summary,
      "devices": device_summary,
      "devs_and_colocs": combined_summary,
      "defined_at": defined_at,
      "file": filename,
      "line": lineno,
  }
  return field_dict


def traceback_files_common_prefix(all_ops):
  """Determines the common prefix from the paths of the stacktrace of 'all_ops'.

  For example, if the paths are '/foo/bar/baz/' and '/foo/car', this would
  return '/foo'.

  Args:
    all_ops: All the input nodes in the form of a list of lists of ops.

  Returns:
    The common prefix.
  """
  files = set()
  for ops in all_ops:
    if ops is None:
      continue
    for op in ops:
      # TODO(slebedev): switch to .filename once 2.X support is dropped.
      for filename, _, _, _ in op.traceback:
        # Embedded (synthetic) filenames would distort the common prefix.
        if "<embedded" not in filename:
          files.add(filename)
  return os.path.split(os.path.commonprefix(list(files)))[0]


def _sources_for_node(node, graph):
  """Gets the input op nodes for 'node'.

  Args:
    node: The node.
    graph: The graph containing the node.

  Returns:
    The unique input nodes.
  """
  inputs = set()
  for name in node.node_def.input:
    # A leading '^' marks a control input; strip it to get the op name.
    if name.startswith("^"):
      name = name[1:]
    try:
      tensor = graph.get_tensor_by_name(name)
      op = tensor.op
    except (KeyError, ValueError):
      try:
        op = graph.get_operation_by_name(name)
      except KeyError:
        # Skip inputs that cannot be resolved in this graph.
        continue
    inputs.add(op)

  return list(inputs)


def _build_error_message(op, input_ops, common_prefix):
  """Returns the formatted error message for the given op.

  Args:
    op: The node.
    input_ops: The input nodes to the 'op' node
    common_prefix: The prefix path common to the stacktrace of inputs.

  Returns:
    The formatted error message for the given op. The error message also
    includes the information about the input sources for the given op.
  """
  field_dict = _compute_field_dict(op, common_prefix)
  msg = "node %s%s " % (op.name, field_dict["defined_at"])
  input_debug_info = []
  # This stores the line numbers that we have already printed.
  done = set()
  done.add(field_dict["defined_at"])
  for op_inp in input_ops:
    field_dict_inp = _compute_field_dict(op_inp, common_prefix)
    if field_dict_inp["defined_at"] not in done:
      input_debug_info.append(
          " %s%s" % (op_inp.name, field_dict_inp["defined_at"]))
      done.add(field_dict_inp["defined_at"])
  if input_debug_info:
    end_msg = ("\nInput Source operations connected to node %s:\n") % (op.name)
    end_msg += "\t\n".join(input_debug_info)
  else:
    end_msg = ""
  return msg, end_msg


def interpolate(error_message, graph):
  """Interpolates an error message.

  The error message can contain tags of the form `{{type name}}` which will be
  replaced. For example: "{{node <name>}}" would get expanded to:
  "node <name>(defined at <path>)".

  Args:
    error_message: A string to interpolate.
    graph: ops.Graph object containing all nodes referenced in the error
      message.

  Returns:
    The string with tags of the form {{type name}} interpolated.
  """
  seps, tags = parse_message(error_message)
  subs = []
  end_msg = collections.defaultdict(list)
  tagged_ops = []

  # Resolve each tag name to its op plus that op's input ops, or None if the
  # op cannot be found in the graph.
  for t in tags:
    try:
      op = graph.get_operation_by_name(t.name)
    except KeyError:
      op = None
    if op is None:
      tagged_ops.append(None)
    else:
      tagged_ops.append([op] + _sources_for_node(op, graph))

  common_prefix = traceback_files_common_prefix(tagged_ops)
  for tag, ops in zip(tags, tagged_ops):
    # Default substitution: leave the tag text unchanged.
    msg = "{{%s %s}}" % (tag.type, tag.name)
    if ops is not None:
      if tag.type == "node":
        msg, source_msg = _build_error_message(ops[0], ops[1:], common_prefix)
        if source_msg:
          end_msg["source_nodes"].append(source_msg)
      elif tag.type == "colocation_node":
        field_dict = _compute_field_dict(ops[0], common_prefix)
        msg = "node %s%s placed on device %s " % (
            ops[0].name, field_dict["defined_at"], field_dict["devices"])
        end_msg["colocations"].append(field_dict["devs_and_colocs"])
    # function_node tags are erased even when the op was not found.
    if tag.type == "function_node":
      msg = ""
    subs.append(msg)

  if "source_nodes" in end_msg:
    subs.append("\n\nErrors may have originated from an input operation.")
    subs.append("\n".join(end_msg["source_nodes"]))
    end_msg.pop("source_nodes", None)
  for k, messages in end_msg.items():
    subs.append("Additional information about %s:" % k)
    subs.append("\n".join(messages))

  # Interleave separators and substitutions back into one string.
  return "".join(
      itertools.chain(*six.moves.zip_longest(seps, subs, fillvalue="")))