1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Function for interpolating formatted errors from the TensorFlow runtime. 16 17Exposes the function `interpolate` to interpolate messages with tags of the form 18{{type name}}. 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import collections 26import itertools 27import os 28import re 29import site 30import traceback 31 32import six 33 34from tensorflow.core.protobuf import graph_debug_info_pb2 35 36_NAME_REGEX = r"[A-Za-z0-9_.][A-Za-z0-9_.\-/]*?" 37_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX) 38_INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX) 39_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL) 40 41_ParseTag = collections.namedtuple("_ParseTag", ["type", "name"]) 42 43 44# Remove the last three path components from this module's file (i.e. 45# python/framework/error_interpolation.py) so that we have an absolute path 46# prefix to the root of the installation. 47_FRAMEWORK_COMMON_PREFIX = os.path.dirname( 48 os.path.dirname(os.path.dirname(__file__))) 49 50# Sub-directories under the common prefix that are considered part of the 51# framework. 52# Note that keras code lives outside of tensorflow directory, we need to walk 53# up the directory tree and find it. 54_FRAMEWORK_PATH_PREFIXES = [ 55 os.path.join(_FRAMEWORK_COMMON_PREFIX, "python") + os.sep, 56 os.path.join(_FRAMEWORK_COMMON_PREFIX, "contrib") + os.sep, 57 os.path.join(os.path.dirname(_FRAMEWORK_COMMON_PREFIX), 58 "py", "keras") + os.sep, 59] 60 61# Patterns of filename patterns that should be considered internal to 62# the TensorFlow framework. 63_FRAMEWORK_FILENAME_PATTERNS = [ 64 re.compile(r"<embedded"), 65] 66 67# This is for OSS keras, since the package is load from local python env, 68# but we don't know exactly where it is installed. Matching to keyword 69# "keras". 70try: 71 _FRAMEWORK_PATH_PREFIXES.extend([ 72 os.path.join(package_path, "keras") + os.sep 73 for package_path in site.getsitepackages() + [site.getusersitepackages()] 74 ]) 75except AttributeError: 76 # if site.getsitepackages is not available somehow, we just use the "keras" as 77 # the keyword to do the match. 78 _FRAMEWORK_FILENAME_PATTERNS.append(re.compile(r"keras")) 79 80# Patterns of filename patterns that should be considered external to 81# TensorFlow regardless of framework prefix match. 82_EXTERNAL_FILENAME_PATTERNS = [ 83 # Explicitly treat test frames as not part of the framework. 84 re.compile(r"_test\.py$"), 85] 86 87 88def parse_message(message): 89 """Parses the message. 90 91 Splits the message into separators and tags. Tags are named tuples 92 representing the string {{type name}} and they are separated by 93 separators. For example, in "123{{node Foo}}456{{node Bar}}789", there are 94 two tags and three separators. The separators are the numeric characters. 95 96 Args: 97 message: String to parse 98 99 Returns: 100 (list of separator strings, list of _ParseTags). 101 102 For example, if message is "123{{node Foo}}456" then this function 103 returns (["123", "456"], [_ParseTag("node", "Foo")]) 104 """ 105 seps = [] 106 tags = [] 107 pos = 0 108 while pos < len(message): 109 match = re.match(_INTERPOLATION_PATTERN, message[pos:]) 110 if match: 111 seps.append(match.group(1)) 112 tags.append(_ParseTag(match.group(3), match.group(4))) 113 pos += match.end() 114 else: 115 break 116 seps.append(message[pos:]) 117 return seps, tags 118 119 120def _compute_device_summary_from_list(name, device_assignment_list, prefix=""): 121 """Return a summary of an op's device function stack. 122 123 Args: 124 name: The name of the op. 125 device_assignment_list: The op._device_assignments list. 126 prefix: An optional string prefix used before each line of the multi- 127 line string returned by this function. 128 129 Returns: 130 A multi-line string similar to: 131 Device assignments active during op 'foo' creation: 132 with tf.device(/cpu:0): <test_1.py:27> 133 with tf.device(some_func<foo.py, 123>): <test_2.py:38> 134 The first line will have no padding to its left by default. Subsequent 135 lines will have two spaces of left-padding. Use the prefix argument 136 to increase indentation. 137 """ 138 if not device_assignment_list: 139 message = "No device assignments were active during op '%s' creation." 140 message %= name 141 return prefix + message 142 143 str_list = [] 144 str_list.append( 145 "%sDevice assignments active during op '%s' creation:" % (prefix, name)) 146 147 for traceable_obj in device_assignment_list: 148 location_summary = "<{file}:{line}>".format( 149 file=traceable_obj.filename, line=traceable_obj.lineno) 150 subs = { 151 "prefix": prefix, 152 "indent": " ", 153 "dev_name": traceable_obj.obj, 154 "loc": location_summary, 155 } 156 str_list.append( 157 "{prefix}{indent}with tf.device({dev_name}): {loc}".format(**subs)) 158 159 return "\n".join(str_list) 160 161 162def _compute_device_assignment_summary_from_op(op, prefix=""): 163 # pylint: disable=protected-access 164 return _compute_device_summary_from_list(op.name, op._device_assignments, 165 prefix) 166 # pylint: enable=protected-access 167 168 169def _compute_colocation_summary_from_dict(name, colocation_dict, prefix=""): 170 """Return a summary of an op's colocation stack. 171 172 Args: 173 name: The op name. 174 colocation_dict: The op._colocation_dict. 175 prefix: An optional string prefix used before each line of the multi- 176 line string returned by this function. 177 178 Returns: 179 A multi-line string similar to: 180 Node-device colocations active during op creation: 181 with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27> 182 with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38> 183 The first line will have no padding to its left by default. Subsequent 184 lines will have two spaces of left-padding. Use the prefix argument 185 to increase indentation. 186 """ 187 if not colocation_dict: 188 message = "No node-device colocations were active during op '%s' creation." 189 message %= name 190 return prefix + message 191 192 str_list = [] 193 str_list.append("%sNode-device colocations active during op '%s' creation:" % 194 (prefix, name)) 195 196 for coloc_name, location in colocation_dict.items(): 197 location_summary = "<{file}:{line}>".format( 198 file=location.filename, line=location.lineno) 199 subs = { 200 "prefix": prefix, 201 "indent": " ", 202 "name": coloc_name, 203 "loc": location_summary, 204 } 205 str_list.append( 206 "{prefix}{indent}with tf.colocate_with({name}): {loc}".format(**subs)) 207 208 return "\n".join(str_list) 209 210 211def _compute_colocation_summary_from_op(op, prefix=""): 212 """Fetch colocation file, line, and nesting and return a summary string.""" 213 # pylint: disable=protected-access 214 return _compute_colocation_summary_from_dict(op.name, op._colocation_dict, 215 prefix) 216 # pylint: enable=protected-access 217 218 219def _is_framework_filename(filename): 220 """Returns whether a filename should be considered a part of the framework. 221 222 A file is part of the framework if it does not match a pattern in 223 _EXTERNAL_FILENAME_PATTERNS and it either matches a pattern in 224 _FRAMEWORK_FILENAME_PATTERNS or starts with a _FRAMEWORK_PATH_PREFIXES prefix. 225 226 Args: 227 filename: A filename string. 228 229 Returns: 230 Whether the filename should be considered to be internal to the 231 TensorFlow framework for the purposes of reporting errors. 232 """ 233 for pattern in _EXTERNAL_FILENAME_PATTERNS: 234 if pattern.search(filename): 235 return False 236 for pattern in _FRAMEWORK_FILENAME_PATTERNS: 237 if pattern.search(filename): 238 return True 239 for prefix in _FRAMEWORK_PATH_PREFIXES: 240 if filename.startswith(prefix): 241 return True 242 return False 243 244 245def _find_index_of_defining_frame(tb): 246 """Return index in op.traceback with first 'useful' frame. 247 248 This method reads through the stack stored in op.traceback looking for the 249 innermost frame which (hopefully) belongs to the caller. It accomplishes this 250 by rejecting frames deemed to be part of the TensorFlow framework (by 251 pattern matching the filename). 252 253 Args: 254 tb: A list of traceback frames (as from Operation.traceback). 255 256 Returns: 257 Integer index into op.traceback where the first non-TF file was found 258 (innermost to outermost), or 0 (for the outermost stack frame) if all files 259 came from TensorFlow. 260 """ 261 # Index 0 of traceback is the outermost frame. 262 size = len(tb) 263 filenames = [frame.filename for frame in tb] 264 # We process the filenames from the innermost frame to outermost. 265 for idx, filename in enumerate(reversed(filenames)): 266 is_framework = _is_framework_filename(filename) 267 if not is_framework: 268 # Consider this to be the defining frame. 269 return size - idx - 1 270 return 0 271 272 273# TODO(feyu): follow up with users of this function (saved model) 274# to see what 'useful' means and whether we can obliviate this. 275def _compute_useful_frames(tb, num): 276 """Return a list of frames, which form a 'useful' stack. 277 278 Starting from the defining frame to the outermost one, this method computes 279 the contiguous portion of the 'useful' stack trace and returns the selected 280 frames. 281 282 Args: 283 tb: A list of traceback frames (as from Operation.traceback). 284 num: total number of frames to return. 285 286 Returns: 287 A list of frames. 288 """ 289 defining_frame_index = _find_index_of_defining_frame(tb) 290 # The stack trace is collected from two lines before the defining frame in the 291 # model file to the outermost with `num` frames at most. These two extra lines 292 # are included from the TensorFlow library to give the context which node is 293 # defined. 294 innermost_excluded = min(defining_frame_index + 2 + 1, len(tb)) 295 outermost_included = max(innermost_excluded - num, 0) 296 return tb[outermost_included:innermost_excluded] 297 298 299def create_graph_debug_info_def(func_named_operations): 300 """Construct and returns a `GraphDebugInfo` protocol buffer. 301 302 Args: 303 func_named_operations: An iterable of (func_name, op.Operation) tuples 304 where the Operation instances have a _traceback members. The func_name 305 should be the empty string for operations in the top-level Graph. 306 307 Returns: 308 GraphDebugInfo protocol buffer. 309 310 Raises: 311 TypeError: If the arguments are not of the correct proto buffer type. 312 """ 313 # Creates an empty GraphDebugInfoDef proto. 314 graph_debug_info_def = graph_debug_info_pb2.GraphDebugInfo() 315 316 # Gets the file names and line numbers for the exported node names. Also 317 # collects the unique file names. 318 all_file_names = set() 319 node_to_trace = {} 320 for func_name, op in func_named_operations: 321 try: 322 op_traceback = op.traceback 323 except AttributeError: 324 # Some ops synthesized on as part of function or control flow definition 325 # do not have tracebacks. 326 continue 327 328 # Gets the stack trace of the operation and then the file location. 329 node_name = op.name + "@" + func_name 330 node_to_trace[node_name] = _compute_useful_frames(op_traceback, 10) 331 for frame in node_to_trace[node_name]: 332 all_file_names.add(frame.filename) 333 334 # Sets the `files` field in the GraphDebugInfo proto 335 graph_debug_info_def.files.extend(all_file_names) 336 337 # Builds a mapping between file names and index of the `files` field, so we 338 # only store the indexes for the nodes in the GraphDebugInfo. 339 file_to_index = dict( 340 [(y, x) for x, y in enumerate(graph_debug_info_def.files)]) 341 342 # Creates the FileLineCol proto for each node and sets the value in the 343 # GraphDebugInfo proto. We only store the file name index for each node to 344 # save the storage space. 345 for node_name, frames in node_to_trace.items(): 346 trace_def = graph_debug_info_def.traces[node_name] 347 for frame in reversed(frames): 348 trace_def.file_line_cols.add( 349 file_index=file_to_index[frame.filename], 350 line=frame.lineno) 351 352 return graph_debug_info_def 353 354 355def _compute_field_dict(op): 356 r"""Return a dictionary mapping interpolation tokens to values. 357 358 Args: 359 op: op.Operation object having a _traceback member. 360 361 Returns: 362 A dictionary mapping string tokens to string values. The keys are shown 363 below along with example values. 364 { 365 "file": "tool_utils.py", 366 "lineno": "124", 367 "line": " source code line", 368 "defined_at": " (defined at tool_utils.py:124)", 369 "colocations": 370 '''Node-device colocations active during op creation: 371 with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27> 372 with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>''' 373 "devices": 374 '''Device assignments active during op 'foo' creation: 375 with tf.device(/cpu:0): <test_1.py:27> 376 with tf.device(some_func<foo.py, 123>): <test_2.py:38>''' 377 "devs_and_colocs": A concatenation of colocations and devices, e.g. 378 '''Node-device colocations active during op creation: 379 with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27> 380 with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>''' 381 Device assignments active during op 'foo' creation: 382 with tf.device(/cpu:0): <test_1.py:27> 383 with tf.device(some_func<foo.py, 123>): <test_2.py:38>''' 384 } 385 """ 386 colocation_summary = _compute_colocation_summary_from_op(op) 387 device_summary = _compute_device_assignment_summary_from_op(op) 388 combined_summary = "\n".join([colocation_summary, device_summary]) 389 390 # Optional traceback info. 391 try: 392 tb = op.traceback 393 except AttributeError: 394 # Some ops synthesized on as part of function or control flow definition 395 # do not have tracebacks. 396 filename = "<unknown>" 397 lineno = 0 398 defined_at = " (defined at <unknown>)" 399 definition_traceback = "" 400 line = "" 401 else: 402 frame = tb.last_user_frame() 403 filename = frame.filename 404 definition_traceback = traceback.format_list(tb.get_user_frames()) 405 406 lineno = frame.lineno 407 line = frame.line 408 defined_at = " (defined at %s:%d)" % (filename, lineno) 409 410 field_dict = { 411 "colocations": colocation_summary, 412 "devices": device_summary, 413 "devs_and_colocs": combined_summary, 414 "defined_at": defined_at, 415 "file": filename, 416 "lineno": lineno, 417 "line": line, 418 "definition_traceback": definition_traceback, 419 } 420 return field_dict 421 422 423def _get_input_ops_for_op(op, graph): 424 """Gets the input ops for op. 425 426 We do a best effort and may not always find all input Ops. 427 428 Args: 429 op: The op node. 430 graph: The graph containing the node. 431 432 Returns: 433 A list of (ind_inp, op_inp). 434 ind_inp: index in the input list. 435 op_inp: op_inp, the input Op at ind_inp in the input list. 436 """ 437 inputs = [] 438 for ind_inp, name in enumerate(op.node_def.input): 439 if name.startswith("^"): 440 name = name[1:] 441 try: 442 tensor = graph.get_tensor_by_name(name) 443 op_inp = tensor.op 444 except (KeyError, ValueError): 445 try: 446 op_inp = graph.get_operation_by_name(name) 447 except KeyError: 448 continue 449 inputs.append((ind_inp, op_inp)) 450 451 return inputs 452 453 454def _build_error_message(op, input_ops): 455 """Returns the formatted error message for the given op. 456 457 Args: 458 op: The node. 459 input_ops: The input nodes to the 'op' node 460 461 Returns: 462 The formatted error message for the given op. The error message also 463 includes the information about the input sources for the given op. 464 """ 465 field_dict = _compute_field_dict(op) 466 msg = "node %s\n%s\n" % (op.name, field_dict["defined_at"]) 467 input_debug_info = [] 468 # This stores the line numbers that we have already printed. 469 done = set() 470 done.add(field_dict["defined_at"]) 471 for ind_inp, op_inp in input_ops: 472 field_dict_inp = _compute_field_dict(op_inp) 473 if field_dict_inp["defined_at"] not in done: 474 input_debug_info.append( 475 "In[%d] %s%s" % (ind_inp, op_inp.name, field_dict_inp["defined_at"])) 476 done.add(field_dict_inp["defined_at"]) 477 else: 478 input_debug_info.append("In[%d] %s:" % (ind_inp, op_inp.name)) 479 480 end_msg = "" 481 if input_debug_info: 482 end_msg += ("\nInput Source operations connected to node %s:\n") % (op.name) 483 end_msg += "\t\n".join(input_debug_info) 484 485 end_msg += "\n\nOperation defined at: (most recent call last)\n" 486 487 definition_traceback = "\n".join(field_dict["definition_traceback"]) 488 # Adds a prefix to differentiate from a Python Interpreter traceback. 489 end_msg += "\n".join([">>> " + s for s in definition_traceback.split("\n")]) 490 491 return msg, end_msg 492 493 494def interpolate(error_message, graph): 495 """Interpolates an error message. 496 497 The error message can contain tags of the form `{{type name}}` which will be 498 replaced. For example: "{{node <name>}}" would get expanded to: 499 "node <name>(defined at <path>)". 500 501 Args: 502 error_message: A string to interpolate. 503 graph: ops.Graph object containing all nodes referenced in the error 504 message. 505 506 Returns: 507 The string with tags of the form {{type name}} interpolated. 508 """ 509 seps, tags = parse_message(error_message) 510 subs = [] 511 end_msg = collections.defaultdict(list) 512 tagged_ops = [] 513 all_ops = [] 514 515 for t in tags: 516 try: 517 op = graph.get_operation_by_name(t.name) 518 except KeyError: 519 op = None 520 if op is None: 521 tagged_ops.append((None, None)) 522 else: 523 op_inps = _get_input_ops_for_op(op, graph) 524 tagged_ops.append((op, op_inps)) 525 for _, op_inp in op_inps: 526 all_ops.append(op_inp) 527 528 for tag, (op, op_inps), in zip(tags, tagged_ops): 529 msg = "{{%s %s}}" % (tag.type, tag.name) 530 if op is not None: 531 if tag.type == "node": 532 msg, source_msg = _build_error_message(op, op_inps) 533 if source_msg: 534 end_msg["source_nodes"].append(source_msg) 535 elif tag.type == "colocation_node": 536 field_dict = _compute_field_dict(op) 537 msg = "node %s%s placed on device %s " % ( 538 op.name, field_dict["defined_at"], field_dict["devices"]) 539 end_msg["colocations"].append(field_dict["devs_and_colocs"]) 540 if tag.type == "function_node": 541 msg = "" 542 subs.append(msg) 543 544 if "source_nodes" in end_msg: 545 subs.append("\n\nErrors may have originated from an input operation.") 546 subs.append("\n".join(end_msg["source_nodes"])) 547 end_msg.pop("source_nodes", None) 548 for k, messages in end_msg.items(): 549 subs.append("Additional information about %s:" % k) 550 subs.append("\n".join(messages)) 551 552 return "".join( 553 itertools.chain(*six.moves.zip_longest(seps, subs, fillvalue=""))) 554