1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""CLI Backend for the Analyzer Part of the Debugger. 16 17The analyzer performs post hoc analysis of dumped intermediate tensors and 18graph structure information from debugged Session.run() calls. 19""" 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import argparse 25import copy 26import re 27 28from six.moves import xrange # pylint: disable=redefined-builtin 29 30from tensorflow.python.debug.cli import cli_config 31from tensorflow.python.debug.cli import cli_shared 32from tensorflow.python.debug.cli import command_parser 33from tensorflow.python.debug.cli import debugger_cli_common 34from tensorflow.python.debug.cli import evaluator 35from tensorflow.python.debug.cli import ui_factory 36from tensorflow.python.debug.lib import debug_graphs 37from tensorflow.python.debug.lib import source_utils 38 39RL = debugger_cli_common.RichLine 40 41# String constants for the depth-dependent hanging indent at the beginning 42# of each line. 43HANG_UNFINISHED = "| " # Used for unfinished recursion depths. 44HANG_FINISHED = " " 45HANG_SUFFIX = "|- " 46 47# String constant for displaying depth and op type. 
DEPTH_TEMPLATE = "(%d) "
OP_TYPE_TEMPLATE = "[%s] "

# String constants for control inputs/outputs, etc.
CTRL_LABEL = "(Ctrl) "
ELLIPSIS = "..."

# Valid values of the `-s`/`--sort_by` flag of the list_tensors command.
SORT_TENSORS_BY_TIMESTAMP = "timestamp"
SORT_TENSORS_BY_DUMP_SIZE = "dump_size"
SORT_TENSORS_BY_OP_TYPE = "op_type"
SORT_TENSORS_BY_TENSOR_NAME = "tensor_name"


def _add_main_menu(output,
                   node_name=None,
                   enable_list_tensors=True,
                   enable_node_info=True,
                   enable_print_tensor=True,
                   enable_list_inputs=True,
                   enable_list_outputs=True):
  """Generate main menu for the screen output from a command.

  Args:
    output: (debugger_cli_common.RichTextLines) the output object to modify.
    node_name: (str or None) name of the node involved (if any). If None,
      the menu items node_info, list_inputs and list_outputs will be
      automatically disabled, overriding the values of arguments
      enable_node_info, enable_list_inputs and enable_list_outputs.
    enable_list_tensors: (bool) whether the list_tensor menu item will be
      enabled.
    enable_node_info: (bool) whether the node_info item will be enabled.
    enable_print_tensor: (bool) whether the print_tensor item will be enabled.
    enable_list_inputs: (bool) whether the item list_inputs will be enabled.
    enable_list_outputs: (bool) whether the item list_outputs will be enabled.
  """

  menu = debugger_cli_common.Menu()

  menu.append(
      debugger_cli_common.MenuItem(
          "list_tensors", "list_tensors", enabled=enable_list_tensors))

  if node_name:
    menu.append(
        debugger_cli_common.MenuItem(
            "node_info",
            "node_info -a -d -t %s" % node_name,
            enabled=enable_node_info))
    menu.append(
        debugger_cli_common.MenuItem(
            "print_tensor",
            "print_tensor %s" % node_name,
            enabled=enable_print_tensor))
    menu.append(
        debugger_cli_common.MenuItem(
            "list_inputs",
            "list_inputs -c -r %s" % node_name,
            enabled=enable_list_inputs))
    menu.append(
        debugger_cli_common.MenuItem(
            "list_outputs",
            "list_outputs -c -r %s" % node_name,
            enabled=enable_list_outputs))
  else:
    # Without a node name, the node-specific items are shown but disabled.
    menu.append(
        debugger_cli_common.MenuItem(
            "node_info", None, enabled=False))
    menu.append(
        debugger_cli_common.MenuItem("print_tensor", None, enabled=False))
    menu.append(
        debugger_cli_common.MenuItem("list_inputs", None, enabled=False))
    menu.append(
        debugger_cli_common.MenuItem("list_outputs", None, enabled=False))

  menu.append(
      debugger_cli_common.MenuItem("run_info", "run_info"))
  menu.append(
      debugger_cli_common.MenuItem("help", "help"))

  output.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu


class DebugAnalyzer(object):
  """Analyzer for debug data from dump directories."""

  _TIMESTAMP_COLUMN_HEAD = "t (ms)"
  _DUMP_SIZE_COLUMN_HEAD = "Size (B)"
  _OP_TYPE_COLUMN_HEAD = "Op type"
  _TENSOR_NAME_COLUMN_HEAD = "Tensor name"

  # Op types to be omitted when generating descriptions of graph structure.
  _GRAPH_STRUCT_OP_TYPE_DENYLIST = ("_Send", "_Recv", "_HostSend", "_HostRecv",
                                    "_Retval")

  def __init__(self, debug_dump, config):
    """DebugAnalyzer constructor.

    Args:
      debug_dump: A DebugDumpDir object.
      config: A `cli_config.CLIConfig` object that carries user-facing
        configurations.
149 """ 150 151 self._debug_dump = debug_dump 152 self._evaluator = evaluator.ExpressionEvaluator(self._debug_dump) 153 154 # Initialize tensor filters state. 155 self._tensor_filters = {} 156 157 self._build_argument_parsers(config) 158 config.set_callback("graph_recursion_depth", 159 self._build_argument_parsers) 160 161 # TODO(cais): Implement list_nodes. 162 163 def _build_argument_parsers(self, config): 164 """Build argument parsers for DebugAnalayzer. 165 166 Args: 167 config: A `cli_config.CLIConfig` object. 168 169 Returns: 170 A dict mapping command handler name to `ArgumentParser` instance. 171 """ 172 # Argument parsers for command handlers. 173 self._arg_parsers = {} 174 175 # Parser for list_tensors. 176 ap = argparse.ArgumentParser( 177 description="List dumped intermediate tensors.", 178 usage=argparse.SUPPRESS) 179 ap.add_argument( 180 "-f", 181 "--tensor_filter", 182 dest="tensor_filter", 183 type=str, 184 default="", 185 help="List only Tensors passing the filter of the specified name") 186 ap.add_argument( 187 "-fenn", 188 "--filter_exclude_node_names", 189 dest="filter_exclude_node_names", 190 type=str, 191 default="", 192 help="When applying the tensor filter, exclude node with names " 193 "matching the regular expression. 
Applicable only if --tensor_filter " 194 "or -f is used.") 195 ap.add_argument( 196 "-n", 197 "--node_name_filter", 198 dest="node_name_filter", 199 type=str, 200 default="", 201 help="filter node name by regex.") 202 ap.add_argument( 203 "-t", 204 "--op_type_filter", 205 dest="op_type_filter", 206 type=str, 207 default="", 208 help="filter op type by regex.") 209 ap.add_argument( 210 "-s", 211 "--sort_by", 212 dest="sort_by", 213 type=str, 214 default=SORT_TENSORS_BY_TIMESTAMP, 215 help=("the field to sort the data by: (%s | %s | %s | %s)" % 216 (SORT_TENSORS_BY_TIMESTAMP, SORT_TENSORS_BY_DUMP_SIZE, 217 SORT_TENSORS_BY_OP_TYPE, SORT_TENSORS_BY_TENSOR_NAME))) 218 ap.add_argument( 219 "-r", 220 "--reverse", 221 dest="reverse", 222 action="store_true", 223 help="sort the data in reverse (descending) order") 224 self._arg_parsers["list_tensors"] = ap 225 226 # Parser for node_info. 227 ap = argparse.ArgumentParser( 228 description="Show information about a node.", usage=argparse.SUPPRESS) 229 ap.add_argument( 230 "node_name", 231 type=str, 232 help="Name of the node or an associated tensor, e.g., " 233 "hidden1/Wx_plus_b/MatMul, hidden1/Wx_plus_b/MatMul:0") 234 ap.add_argument( 235 "-a", 236 "--attributes", 237 dest="attributes", 238 action="store_true", 239 help="Also list attributes of the node.") 240 ap.add_argument( 241 "-d", 242 "--dumps", 243 dest="dumps", 244 action="store_true", 245 help="Also list dumps available from the node.") 246 ap.add_argument( 247 "-t", 248 "--traceback", 249 dest="traceback", 250 action="store_true", 251 help="Also include the traceback of the node's creation " 252 "(if available in Python).") 253 self._arg_parsers["node_info"] = ap 254 255 # Parser for list_inputs. 
    ap = argparse.ArgumentParser(
        description="Show inputs to a node.", usage=argparse.SUPPRESS)
    ap.add_argument(
        "node_name",
        type=str,
        help="Name of the node or an output tensor from the node, e.g., "
        "hidden1/Wx_plus_b/MatMul, hidden1/Wx_plus_b/MatMul:0")
    ap.add_argument(
        "-c", "--control", action="store_true", help="Include control inputs.")
    ap.add_argument(
        "-d",
        "--depth",
        dest="depth",
        type=int,
        default=config.get("graph_recursion_depth"),
        help="Maximum depth of recursion used when showing the input tree.")
    ap.add_argument(
        "-r",
        "--recursive",
        dest="recursive",
        action="store_true",
        help="Show inputs to the node recursively, i.e., the input tree.")
    ap.add_argument(
        "-t",
        "--op_type",
        action="store_true",
        help="Show op types of input nodes.")
    self._arg_parsers["list_inputs"] = ap

    # Parser for list_outputs.
    ap = argparse.ArgumentParser(
        description="Show the nodes that receive the outputs of given node.",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "node_name",
        type=str,
        help="Name of the node or an output tensor from the node, e.g., "
        "hidden1/Wx_plus_b/MatMul, hidden1/Wx_plus_b/MatMul:0")
    ap.add_argument(
        "-c", "--control", action="store_true", help="Include control inputs.")
    ap.add_argument(
        "-d",
        "--depth",
        dest="depth",
        type=int,
        default=config.get("graph_recursion_depth"),
        help="Maximum depth of recursion used when showing the output tree.")
    ap.add_argument(
        "-r",
        "--recursive",
        dest="recursive",
        action="store_true",
        help="Show recipients of the node recursively, i.e., the output "
        "tree.")
    ap.add_argument(
        "-t",
        "--op_type",
        action="store_true",
        help="Show op types of recipient nodes.")
    self._arg_parsers["list_outputs"] = ap

    # Parser for print_tensor.
    self._arg_parsers["print_tensor"] = (
        command_parser.get_print_tensor_argparser(
            "Print the value of a dumped tensor."))

    # Parser for print_source.
    ap = argparse.ArgumentParser(
        description="Print a Python source file with overlaid debug "
        "information, including the nodes (ops) or Tensors created at the "
        "source lines.",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "source_file_path",
        type=str,
        help="Path to the source file.")
    ap.add_argument(
        "-t",
        "--tensors",
        dest="tensors",
        action="store_true",
        help="Label lines with dumped Tensors, instead of ops.")
    ap.add_argument(
        "-m",
        "--max_elements_per_line",
        type=int,
        default=10,
        help="Maximum number of elements (ops or Tensors) to show per source "
        "line.")
    ap.add_argument(
        "-b",
        "--line_begin",
        type=int,
        default=1,
        help="Print source beginning at line number (1-based.)")
    self._arg_parsers["print_source"] = ap

    # Parser for list_source.
    ap = argparse.ArgumentParser(
        description="List source files responsible for constructing nodes and "
        "tensors present in the run().",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "-p",
        "--path_filter",
        type=str,
        default="",
        help="Regular expression filter for file path.")
    ap.add_argument(
        "-n",
        "--node_name_filter",
        type=str,
        default="",
        help="Regular expression filter for node name.")
    self._arg_parsers["list_source"] = ap

    # Parser for eval.
    ap = argparse.ArgumentParser(
        description="""Evaluate an arbitrary expression. Can use tensor values
        from the current debug dump. The debug tensor names should be enclosed
        in pairs of backticks. Expressions with spaces should be enclosed in
        a pair of double quotes or a pair of single quotes. By default, numpy
        is imported as np and can be used in the expressions. E.g.,
          1) eval np.argmax(`Softmax:0`),
          2) eval 'np.sum(`Softmax:0`, axis=1)',
          3) eval "np.matmul((`output/Identity:0`/`Softmax:0`).T, `Softmax:0`)".
        """,
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "expression",
        type=str,
        help="""Expression to be evaluated.
        1) in the simplest case, use <node_name>:<output_slot>, e.g.,
          hidden_0/MatMul:0.

        2) if the default debug op "DebugIdentity" is to be overridden, use
          <node_name>:<output_slot>:<debug_op>, e.g.,
          hidden_0/MatMul:0:DebugNumericSummary.

        3) if the tensor of the same name exists on more than one device, use
          <device_name>:<node_name>:<output_slot>[:<debug_op>], e.g.,
          /job:worker/replica:0/task:0/gpu:0:hidden_0/MatMul:0
          /job:worker/replica:0/task:2/cpu:0:hidden_0/MatMul:0:DebugNanCount.

        4) if the tensor is executed multiple times in a given `Session.run`
          call, specify the execution index with a 0-based integer enclose in a
          pair of brackets at the end, e.g.,
          RNN/tanh:0[0]
          /job:worker/replica:0/task:0/gpu:0:RNN/tanh:0[0].""")
    ap.add_argument(
        "-a",
        "--all",
        dest="print_all",
        action="store_true",
        help="Print the tensor in its entirety, i.e., do not use ellipses "
        "(may be slow for large results).")
    ap.add_argument(
        "-w",
        "--write_path",
        default="",
        help="Path of the numpy file to write the evaluation result to, "
        "using numpy.save()")
    self._arg_parsers["eval"] = ap

  def add_tensor_filter(self, filter_name, filter_callable):
    """Add a tensor filter.

    A tensor filter is a named callable of the signature:
      filter_callable(dump_datum, tensor),

    wherein dump_datum is an instance of debug_data.DebugTensorDatum carrying
    metadata about the dumped tensor, including tensor name, timestamps, etc.
    tensor is the value of the dumped tensor as an numpy.ndarray object.
    The return value of the function is a bool.
    This is the same signature as the input argument to
    debug_data.DebugDumpDir.find().

    Args:
      filter_name: (str) name of the filter. Cannot be empty.
      filter_callable: (callable) a filter function of the signature described
        as above.

    Raises:
      ValueError: If filter_name is an empty str.
      TypeError: If filter_name is not a str.
        Or if filter_callable is not callable.
    """

    if not isinstance(filter_name, str):
      raise TypeError("Input argument filter_name is expected to be str, "
                      "but is not.")

    # Check that filter_name is not an empty str.
    if not filter_name:
      raise ValueError("Input argument filter_name cannot be empty.")

    # Check that filter_callable is callable.
    if not callable(filter_callable):
      raise TypeError(
          "Input argument filter_callable is expected to be callable, "
          "but is not.")

    self._tensor_filters[filter_name] = filter_callable

  def get_tensor_filter(self, filter_name):
    """Retrieve filter function by name.

    Args:
      filter_name: Name of the filter set during add_tensor_filter() call.

    Returns:
      The callable associated with the filter name.

    Raises:
      ValueError: If there is no tensor filter of the specified filter name.
    """

    if filter_name not in self._tensor_filters:
      raise ValueError("There is no tensor filter named \"%s\"" % filter_name)

    return self._tensor_filters[filter_name]

  def get_help(self, handler_name):
    """Return the argparse-formatted help string for a command handler."""
    return self._arg_parsers[handler_name].format_help()

  def list_tensors(self, args, screen_info=None):
    """Command handler for list_tensors.

    List tensors dumped during debugged Session.run() call.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Optional dict input containing screen information such as
        cols.

    Returns:
      Output text lines as a RichTextLines object.

    Raises:
      ValueError: If `--filter_exclude_node_names` is used without `-f` or
        `--tensor_filter` being used.
    """

    # TODO(cais): Add annotations of substrings for dumped tensor names, to
    # facilitate on-screen highlighting/selection of node names.
    _ = screen_info

    parsed = self._arg_parsers["list_tensors"].parse_args(args)

    # NOTE(review): this empty list is never used; it is overwritten by the
    # RichTextLines construction below.
    output = []

    filter_strs = []
    if parsed.op_type_filter:
      op_type_regex = re.compile(parsed.op_type_filter)
      filter_strs.append("Op type regex filter: \"%s\"" % parsed.op_type_filter)
    else:
      op_type_regex = None

    if parsed.node_name_filter:
      node_name_regex = re.compile(parsed.node_name_filter)
      filter_strs.append("Node name regex filter: \"%s\"" %
                         parsed.node_name_filter)
    else:
      node_name_regex = None

    output = debugger_cli_common.RichTextLines(filter_strs)
    output.append("")

    if parsed.tensor_filter:
      try:
        filter_callable = self.get_tensor_filter(parsed.tensor_filter)
      except ValueError:
        output = cli_shared.error("There is no tensor filter named \"%s\"." %
                                  parsed.tensor_filter)
        _add_main_menu(output, node_name=None, enable_list_tensors=False)
        return output

      data_to_show = self._debug_dump.find(
          filter_callable,
          exclude_node_names=parsed.filter_exclude_node_names)
    else:
      if parsed.filter_exclude_node_names:
        raise ValueError(
            "The flag --filter_exclude_node_names is valid only when "
            "the flag -f or --tensor_filter is used.")

      data_to_show = self._debug_dump.dumped_tensor_data

    # TODO(cais): Implement filter by lambda on tensor value.

    max_timestamp_width, max_dump_size_width, max_op_type_width = (
        self._measure_tensor_list_column_widths(data_to_show))

    # Sort the data.
    data_to_show = self._sort_dump_data_by(
        data_to_show, parsed.sort_by, parsed.reverse)

    output.extend(
        self._tensor_list_column_heads(parsed, max_timestamp_width,
                                       max_dump_size_width, max_op_type_width))

    dump_count = 0
    for dump in data_to_show:
      if node_name_regex and not node_name_regex.match(dump.node_name):
        continue

      if op_type_regex:
        op_type = self._debug_dump.node_op_type(dump.node_name)
        if not op_type_regex.match(op_type):
          continue

      # Timestamps are rendered as milliseconds relative to the first dump.
      rel_time = (dump.timestamp - self._debug_dump.t0) / 1000.0
      dump_size_str = cli_shared.bytes_to_readable_str(dump.dump_size_bytes)
      dumped_tensor_name = "%s:%d" % (dump.node_name, dump.output_slot)
      op_type = self._debug_dump.node_op_type(dump.node_name)

      # Pad each column to the pre-measured width.
      line = "[%.3f]" % rel_time
      line += " " * (max_timestamp_width - len(line))
      line += dump_size_str
      line += " " * (max_timestamp_width + max_dump_size_width - len(line))
      line += op_type
      line += " " * (max_timestamp_width + max_dump_size_width +
                     max_op_type_width - len(line))
      line += dumped_tensor_name

      # The tensor name is clickable and triggers a print_tensor command.
      output.append(
          line,
          font_attr_segs=[(
              len(line) - len(dumped_tensor_name), len(line),
              debugger_cli_common.MenuItem("", "pt %s" % dumped_tensor_name))])
      dump_count += 1

    if parsed.tensor_filter:
      output.prepend([
          "%d dumped tensor(s) passing filter \"%s\":" %
          (dump_count, parsed.tensor_filter)
      ])
    else:
      output.prepend(["%d dumped tensor(s):" % dump_count])

    _add_main_menu(output, node_name=None, enable_list_tensors=False)
    return output

  def _measure_tensor_list_column_widths(self, data):
    """Determine the maximum widths of the timestamp and op-type column.

    This method assumes that data is sorted in the default order, i.e.,
    by ascending timestamps.

    Args:
      data: (list of DebugTensorDatum) the data based on which the maximum
        column widths will be determined.

    Returns:
      (int) maximum width of the timestamp column. 0 if data is empty.
      (int) maximum width of the dump size column. 0 if data is empty.
      (int) maximum width of the op type column. 0 if data is empty.
    """

    max_timestamp_width = 0
    if data:
      # Data is assumed timestamp-sorted, so the last datum has the widest
      # relative-time rendering.
      max_rel_time_ms = (data[-1].timestamp - self._debug_dump.t0) / 1000.0
      max_timestamp_width = len("[%.3f] " % max_rel_time_ms) + 1
    max_timestamp_width = max(max_timestamp_width,
                              len(self._TIMESTAMP_COLUMN_HEAD) + 1)

    max_dump_size_width = 0
    for dump in data:
      dump_size_str = cli_shared.bytes_to_readable_str(dump.dump_size_bytes)
      if len(dump_size_str) + 1 > max_dump_size_width:
        max_dump_size_width = len(dump_size_str) + 1
    max_dump_size_width = max(max_dump_size_width,
                              len(self._DUMP_SIZE_COLUMN_HEAD) + 1)

    max_op_type_width = 0
    for dump in data:
      op_type = self._debug_dump.node_op_type(dump.node_name)
      if len(op_type) + 1 > max_op_type_width:
        max_op_type_width = len(op_type) + 1
    max_op_type_width = max(max_op_type_width,
                            len(self._OP_TYPE_COLUMN_HEAD) + 1)

    return max_timestamp_width, max_dump_size_width, max_op_type_width

  def _sort_dump_data_by(self, data, sort_by, reverse):
    """Sort a list of DebugTensorDatum in specified order.

    Args:
      data: (list of DebugTensorDatum) the data to be sorted.
      sort_by: The field to sort data by.
      reverse: (bool) Whether to use reversed (descending) order.

    Returns:
      (list of DebugTensorDatum) in sorted order.

    Raises:
      ValueError: given an invalid value of sort_by.
654 """ 655 656 if sort_by == SORT_TENSORS_BY_TIMESTAMP: 657 return sorted( 658 data, 659 reverse=reverse, 660 key=lambda x: x.timestamp) 661 elif sort_by == SORT_TENSORS_BY_DUMP_SIZE: 662 return sorted(data, reverse=reverse, key=lambda x: x.dump_size_bytes) 663 elif sort_by == SORT_TENSORS_BY_OP_TYPE: 664 return sorted( 665 data, 666 reverse=reverse, 667 key=lambda x: self._debug_dump.node_op_type(x.node_name)) 668 elif sort_by == SORT_TENSORS_BY_TENSOR_NAME: 669 return sorted( 670 data, 671 reverse=reverse, 672 key=lambda x: "%s:%d" % (x.node_name, x.output_slot)) 673 else: 674 raise ValueError("Unsupported key to sort tensors by: %s" % sort_by) 675 676 def _tensor_list_column_heads(self, parsed, max_timestamp_width, 677 max_dump_size_width, max_op_type_width): 678 """Generate a line containing the column heads of the tensor list. 679 680 Args: 681 parsed: Parsed arguments (by argparse) of the list_tensors command. 682 max_timestamp_width: (int) maximum width of the timestamp column. 683 max_dump_size_width: (int) maximum width of the dump size column. 684 max_op_type_width: (int) maximum width of the op type column. 685 686 Returns: 687 A RichTextLines object. 
688 """ 689 690 base_command = "list_tensors" 691 if parsed.tensor_filter: 692 base_command += " -f %s" % parsed.tensor_filter 693 if parsed.op_type_filter: 694 base_command += " -t %s" % parsed.op_type_filter 695 if parsed.node_name_filter: 696 base_command += " -n %s" % parsed.node_name_filter 697 698 attr_segs = {0: []} 699 row = self._TIMESTAMP_COLUMN_HEAD 700 command = "%s -s %s" % (base_command, SORT_TENSORS_BY_TIMESTAMP) 701 if parsed.sort_by == SORT_TENSORS_BY_TIMESTAMP and not parsed.reverse: 702 command += " -r" 703 attr_segs[0].append( 704 (0, len(row), [debugger_cli_common.MenuItem(None, command), "bold"])) 705 row += " " * (max_timestamp_width - len(row)) 706 707 prev_len = len(row) 708 row += self._DUMP_SIZE_COLUMN_HEAD 709 command = "%s -s %s" % (base_command, SORT_TENSORS_BY_DUMP_SIZE) 710 if parsed.sort_by == SORT_TENSORS_BY_DUMP_SIZE and not parsed.reverse: 711 command += " -r" 712 attr_segs[0].append((prev_len, len(row), 713 [debugger_cli_common.MenuItem(None, command), "bold"])) 714 row += " " * (max_dump_size_width + max_timestamp_width - len(row)) 715 716 prev_len = len(row) 717 row += self._OP_TYPE_COLUMN_HEAD 718 command = "%s -s %s" % (base_command, SORT_TENSORS_BY_OP_TYPE) 719 if parsed.sort_by == SORT_TENSORS_BY_OP_TYPE and not parsed.reverse: 720 command += " -r" 721 attr_segs[0].append((prev_len, len(row), 722 [debugger_cli_common.MenuItem(None, command), "bold"])) 723 row += " " * ( 724 max_op_type_width + max_dump_size_width + max_timestamp_width - len(row) 725 ) 726 727 prev_len = len(row) 728 row += self._TENSOR_NAME_COLUMN_HEAD 729 command = "%s -s %s" % (base_command, SORT_TENSORS_BY_TENSOR_NAME) 730 if parsed.sort_by == SORT_TENSORS_BY_TENSOR_NAME and not parsed.reverse: 731 command += " -r" 732 attr_segs[0].append((prev_len, len(row), 733 [debugger_cli_common.MenuItem("", command), "bold"])) 734 row += " " * ( 735 max_op_type_width + max_dump_size_width + max_timestamp_width - len(row) 736 ) 737 738 return 
debugger_cli_common.RichTextLines([row], font_attr_segs=attr_segs) 739 740 def node_info(self, args, screen_info=None): 741 """Command handler for node_info. 742 743 Query information about a given node. 744 745 Args: 746 args: Command-line arguments, excluding the command prefix, as a list of 747 str. 748 screen_info: Optional dict input containing screen information such as 749 cols. 750 751 Returns: 752 Output text lines as a RichTextLines object. 753 """ 754 755 # TODO(cais): Add annotation of substrings for node names, to facilitate 756 # on-screen highlighting/selection of node names. 757 _ = screen_info 758 759 parsed = self._arg_parsers["node_info"].parse_args(args) 760 761 # Get a node name, regardless of whether the input is a node name (without 762 # output slot attached) or a tensor name (with output slot attached). 763 node_name, unused_slot = debug_graphs.parse_node_or_tensor_name( 764 parsed.node_name) 765 766 if not self._debug_dump.node_exists(node_name): 767 output = cli_shared.error( 768 "There is no node named \"%s\" in the partition graphs" % node_name) 769 _add_main_menu( 770 output, 771 node_name=None, 772 enable_list_tensors=True, 773 enable_node_info=False, 774 enable_list_inputs=False, 775 enable_list_outputs=False) 776 return output 777 778 # TODO(cais): Provide UI glossary feature to explain to users what the 779 # term "partition graph" means and how it is related to TF graph objects 780 # in Python. The information can be along the line of: 781 # "A tensorflow graph defined in Python is stripped of unused ops 782 # according to the feeds and fetches and divided into a number of 783 # partition graphs that may be distributed among multiple devices and 784 # hosts. The partition graphs are what's actually executed by the C++ 785 # runtime during a run() call." 

    lines = ["Node %s" % node_name]
    font_attr_segs = {
        0: [(len(lines[-1]) - len(node_name), len(lines[-1]), "bold")]
    }
    lines.append("")
    lines.append("  Op: %s" % self._debug_dump.node_op_type(node_name))
    lines.append("  Device: %s" % self._debug_dump.node_device(node_name))
    output = debugger_cli_common.RichTextLines(
        lines, font_attr_segs=font_attr_segs)

    # List node inputs (non-control and control).
    inputs = self._exclude_denylisted_ops(
        self._debug_dump.node_inputs(node_name))
    ctrl_inputs = self._exclude_denylisted_ops(
        self._debug_dump.node_inputs(node_name, is_control=True))
    output.extend(self._format_neighbors("input", inputs, ctrl_inputs))

    # List node output recipients (non-control and control).
    recs = self._exclude_denylisted_ops(
        self._debug_dump.node_recipients(node_name))
    ctrl_recs = self._exclude_denylisted_ops(
        self._debug_dump.node_recipients(node_name, is_control=True))
    output.extend(self._format_neighbors("recipient", recs, ctrl_recs))

    # Optional: List attributes of the node.
    if parsed.attributes:
      output.extend(self._list_node_attributes(node_name))

    # Optional: List dumps available from the node.
    if parsed.dumps:
      output.extend(self._list_node_dumps(node_name))

    # Optional: Show the Python traceback of the node's construction.
    if parsed.traceback:
      output.extend(self._render_node_traceback(node_name))

    _add_main_menu(output, node_name=node_name, enable_node_info=False)
    return output

  def _exclude_denylisted_ops(self, node_names):
    """Exclude all nodes whose op types are in _GRAPH_STRUCT_OP_TYPE_DENYLIST.

    Args:
      node_names: An iterable of node or graph element names.

    Returns:
      A list of node names that are not denylisted.
833 """ 834 return [ 835 node_name for node_name in node_names 836 if self._debug_dump.node_op_type(debug_graphs.get_node_name(node_name)) 837 not in self._GRAPH_STRUCT_OP_TYPE_DENYLIST 838 ] 839 840 def _render_node_traceback(self, node_name): 841 """Render traceback of a node's creation in Python, if available. 842 843 Args: 844 node_name: (str) name of the node. 845 846 Returns: 847 A RichTextLines object containing the stack trace of the node's 848 construction. 849 """ 850 851 lines = [RL(""), RL(""), RL("Traceback of node construction:", "bold")] 852 853 try: 854 node_stack = self._debug_dump.node_traceback(node_name) 855 for depth, (file_path, line, function_name, text) in enumerate( 856 node_stack): 857 lines.append("%d: %s" % (depth, file_path)) 858 859 attribute = debugger_cli_common.MenuItem( 860 "", "ps %s -b %d" % (file_path, line)) if text else None 861 line_number_line = RL(" ") 862 line_number_line += RL("Line: %d" % line, attribute) 863 lines.append(line_number_line) 864 865 lines.append(" Function: %s" % function_name) 866 lines.append(" Text: " + (("\"%s\"" % text) if text else "None")) 867 lines.append("") 868 except KeyError: 869 lines.append("(Node unavailable in the loaded Python graph)") 870 except LookupError: 871 lines.append("(Unavailable because no Python graph has been loaded)") 872 873 return debugger_cli_common.rich_text_lines_from_rich_line_list(lines) 874 875 def list_inputs(self, args, screen_info=None): 876 """Command handler for inputs. 877 878 Show inputs to a given node. 879 880 Args: 881 args: Command-line arguments, excluding the command prefix, as a list of 882 str. 883 screen_info: Optional dict input containing screen information such as 884 cols. 885 886 Returns: 887 Output text lines as a RichTextLines object. 888 """ 889 890 # Screen info not currently used by this handler. Include this line to 891 # mute pylint. 
    _ = screen_info
    # TODO(cais): Use screen info to format the output lines more prettily,
    # e.g., hanging indent of long node names.

    parsed = self._arg_parsers["list_inputs"].parse_args(args)

    output = self._list_inputs_or_outputs(
        parsed.recursive,
        parsed.node_name,
        parsed.depth,
        parsed.control,
        parsed.op_type,
        do_outputs=False)

    node_name = debug_graphs.get_node_name(parsed.node_name)
    _add_main_menu(output, node_name=node_name, enable_list_inputs=False)

    return output

  def print_tensor(self, args, screen_info=None):
    """Command handler for print_tensor.

    Print value of a given dumped tensor.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Optional dict input containing screen information such as
        cols.

    Returns:
      Output text lines as a RichTextLines object.
    """

    parsed = self._arg_parsers["print_tensor"].parse_args(args)

    np_printoptions = cli_shared.numpy_printoptions_from_screen_info(
        screen_info)

    # Determine if any range-highlighting is required.
    highlight_options = cli_shared.parse_ranges_highlight(parsed.ranges)

    tensor_name, tensor_slicing = (
        command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))

    node_name, output_slot = debug_graphs.parse_node_or_tensor_name(tensor_name)
    if (self._debug_dump.loaded_partition_graphs() and
        not self._debug_dump.node_exists(node_name)):
      output = cli_shared.error(
          "Node \"%s\" does not exist in partition graphs" % node_name)
      _add_main_menu(
          output,
          node_name=None,
          enable_list_tensors=True,
          enable_print_tensor=False)
      return output

    watch_keys = self._debug_dump.debug_watch_keys(node_name)
    if output_slot is None:
      # No explicit output slot: infer it from the watch keys, whose format
      # includes the slot after the first colon.
      output_slots = set()
      for watch_key in watch_keys:
        output_slots.add(int(watch_key.split(":")[1]))

      if len(output_slots) == 1:
        # There is only one dumped tensor from this node, so there is no
        # ambiguity. Proceed to show the only dumped tensor.
        output_slot = list(output_slots)[0]
      else:
        # There are more than one dumped tensors from this node. Indicate as
        # such.
        # TODO(cais): Provide an output screen with command links for
        # convenience.
        lines = [
            "Node \"%s\" generated debug dumps from %s output slots:" %
            (node_name, len(output_slots)),
            "Please specify the output slot: %s:x." % node_name
        ]
        output = debugger_cli_common.RichTextLines(lines)
        _add_main_menu(
            output,
            node_name=node_name,
            enable_list_tensors=True,
            enable_print_tensor=False)
        return output

    # Find debug dump data that match the tensor name (node name + output
    # slot).
    matching_data = []
    for watch_key in watch_keys:
      debug_tensor_data = self._debug_dump.watch_key_to_data(watch_key)
      for datum in debug_tensor_data:
        if datum.output_slot == output_slot:
          matching_data.append(datum)

    if not matching_data:
      # No dump for this tensor.
988 output = cli_shared.error("Tensor \"%s\" did not generate any dumps." % 989 parsed.tensor_name) 990 elif len(matching_data) == 1: 991 # There is only one dump for this tensor. 992 if parsed.number <= 0: 993 output = cli_shared.format_tensor( 994 matching_data[0].get_tensor(), 995 matching_data[0].watch_key, 996 np_printoptions, 997 print_all=parsed.print_all, 998 tensor_slicing=tensor_slicing, 999 highlight_options=highlight_options, 1000 include_numeric_summary=parsed.numeric_summary, 1001 write_path=parsed.write_path) 1002 else: 1003 output = cli_shared.error( 1004 "Invalid number (%d) for tensor %s, which generated one dump." % 1005 (parsed.number, parsed.tensor_name)) 1006 1007 _add_main_menu(output, node_name=node_name, enable_print_tensor=False) 1008 else: 1009 # There are more than one dumps for this tensor. 1010 if parsed.number < 0: 1011 lines = [ 1012 "Tensor \"%s\" generated %d dumps:" % (parsed.tensor_name, 1013 len(matching_data)) 1014 ] 1015 font_attr_segs = {} 1016 1017 for i, datum in enumerate(matching_data): 1018 rel_time = (datum.timestamp - self._debug_dump.t0) / 1000.0 1019 lines.append("#%d [%.3f ms] %s" % (i, rel_time, datum.watch_key)) 1020 command = "print_tensor %s -n %d" % (parsed.tensor_name, i) 1021 font_attr_segs[len(lines) - 1] = [( 1022 len(lines[-1]) - len(datum.watch_key), len(lines[-1]), 1023 debugger_cli_common.MenuItem(None, command))] 1024 1025 lines.append("") 1026 lines.append( 1027 "You can use the -n (--number) flag to specify which dump to " 1028 "print.") 1029 lines.append("For example:") 1030 lines.append(" print_tensor %s -n 0" % parsed.tensor_name) 1031 1032 output = debugger_cli_common.RichTextLines( 1033 lines, font_attr_segs=font_attr_segs) 1034 elif parsed.number >= len(matching_data): 1035 output = cli_shared.error( 1036 "Specified number (%d) exceeds the number of available dumps " 1037 "(%d) for tensor %s" % 1038 (parsed.number, len(matching_data), parsed.tensor_name)) 1039 else: 1040 output = 
cli_shared.format_tensor( 1041 matching_data[parsed.number].get_tensor(), 1042 matching_data[parsed.number].watch_key + " (dump #%d)" % 1043 parsed.number, 1044 np_printoptions, 1045 print_all=parsed.print_all, 1046 tensor_slicing=tensor_slicing, 1047 highlight_options=highlight_options, 1048 write_path=parsed.write_path) 1049 _add_main_menu(output, node_name=node_name, enable_print_tensor=False) 1050 1051 return output 1052 1053 def list_outputs(self, args, screen_info=None): 1054 """Command handler for inputs. 1055 1056 Show inputs to a given node. 1057 1058 Args: 1059 args: Command-line arguments, excluding the command prefix, as a list of 1060 str. 1061 screen_info: Optional dict input containing screen information such as 1062 cols. 1063 1064 Returns: 1065 Output text lines as a RichTextLines object. 1066 """ 1067 1068 # Screen info not currently used by this handler. Include this line to 1069 # mute pylint. 1070 _ = screen_info 1071 # TODO(cais): Use screen info to format the output lines more prettily, 1072 # e.g., hanging indent of long node names. 
1073 1074 parsed = self._arg_parsers["list_outputs"].parse_args(args) 1075 1076 output = self._list_inputs_or_outputs( 1077 parsed.recursive, 1078 parsed.node_name, 1079 parsed.depth, 1080 parsed.control, 1081 parsed.op_type, 1082 do_outputs=True) 1083 1084 node_name = debug_graphs.get_node_name(parsed.node_name) 1085 _add_main_menu(output, node_name=node_name, enable_list_outputs=False) 1086 1087 return output 1088 1089 def evaluate_expression(self, args, screen_info=None): 1090 parsed = self._arg_parsers["eval"].parse_args(args) 1091 1092 eval_res = self._evaluator.evaluate(parsed.expression) 1093 1094 np_printoptions = cli_shared.numpy_printoptions_from_screen_info( 1095 screen_info) 1096 return cli_shared.format_tensor( 1097 eval_res, 1098 "from eval of expression '%s'" % parsed.expression, 1099 np_printoptions, 1100 print_all=parsed.print_all, 1101 include_numeric_summary=True, 1102 write_path=parsed.write_path) 1103 1104 def _reconstruct_print_source_command(self, 1105 parsed, 1106 line_begin, 1107 max_elements_per_line_increase=0): 1108 return "ps %s %s -b %d -m %d" % ( 1109 parsed.source_file_path, "-t" if parsed.tensors else "", line_begin, 1110 parsed.max_elements_per_line + max_elements_per_line_increase) 1111 1112 def print_source(self, args, screen_info=None): 1113 """Print the content of a source file.""" 1114 del screen_info # Unused. 
1115 1116 parsed = self._arg_parsers["print_source"].parse_args(args) 1117 1118 source_annotation = source_utils.annotate_source( 1119 self._debug_dump, 1120 parsed.source_file_path, 1121 do_dumped_tensors=parsed.tensors) 1122 1123 source_lines, line_num_width = source_utils.load_source( 1124 parsed.source_file_path) 1125 1126 labeled_source_lines = [] 1127 actual_initial_scroll_target = 0 1128 for i, line in enumerate(source_lines): 1129 annotated_line = RL("L%d" % (i + 1), cli_shared.COLOR_YELLOW) 1130 annotated_line += " " * (line_num_width - len(annotated_line)) 1131 annotated_line += line 1132 labeled_source_lines.append(annotated_line) 1133 1134 if i + 1 == parsed.line_begin: 1135 actual_initial_scroll_target = len(labeled_source_lines) - 1 1136 1137 if i + 1 in source_annotation: 1138 sorted_elements = sorted(source_annotation[i + 1]) 1139 for k, element in enumerate(sorted_elements): 1140 if k >= parsed.max_elements_per_line: 1141 omitted_info_line = RL(" (... Omitted %d of %d %s ...) 
" % ( 1142 len(sorted_elements) - parsed.max_elements_per_line, 1143 len(sorted_elements), 1144 "tensor(s)" if parsed.tensors else "op(s)")) 1145 omitted_info_line += RL( 1146 "+5", 1147 debugger_cli_common.MenuItem( 1148 None, 1149 self._reconstruct_print_source_command( 1150 parsed, i + 1, max_elements_per_line_increase=5))) 1151 labeled_source_lines.append(omitted_info_line) 1152 break 1153 1154 label = RL(" " * 4) 1155 if self._debug_dump.debug_watch_keys( 1156 debug_graphs.get_node_name(element)): 1157 attribute = debugger_cli_common.MenuItem("", "pt %s" % element) 1158 else: 1159 attribute = cli_shared.COLOR_BLUE 1160 1161 label += RL(element, attribute) 1162 labeled_source_lines.append(label) 1163 1164 output = debugger_cli_common.rich_text_lines_from_rich_line_list( 1165 labeled_source_lines, 1166 annotations={debugger_cli_common.INIT_SCROLL_POS_KEY: 1167 actual_initial_scroll_target}) 1168 _add_main_menu(output, node_name=None) 1169 return output 1170 1171 def _make_source_table(self, source_list, is_tf_py_library): 1172 """Make a table summarizing the source files that create nodes and tensors. 1173 1174 Args: 1175 source_list: List of source files and related information as a list of 1176 tuples (file_path, is_tf_library, num_nodes, num_tensors, num_dumps, 1177 first_line). 1178 is_tf_py_library: (`bool`) whether this table is for files that belong 1179 to the TensorFlow Python library. 1180 1181 Returns: 1182 The table as a `debugger_cli_common.RichTextLines` object. 1183 """ 1184 path_head = "Source file path" 1185 num_nodes_head = "#(nodes)" 1186 num_tensors_head = "#(tensors)" 1187 num_dumps_head = "#(tensor dumps)" 1188 1189 if is_tf_py_library: 1190 # Use color to mark files that are guessed to belong to TensorFlow Python 1191 # library. 
      color = cli_shared.COLOR_GRAY
      lines = [RL("TensorFlow Python library file(s):", color)]
    else:
      color = cli_shared.COLOR_WHITE
      lines = [RL("File(s) outside TensorFlow Python library:", color)]

    if not source_list:
      lines.append(RL("[No files.]"))
      lines.append(RL())
      return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)

    # Column widths: widest cell in each column (or the header, if wider),
    # plus one space of padding between columns.
    path_column_width = max(
        max(len(item[0]) for item in source_list), len(path_head)) + 1
    num_nodes_column_width = max(
        max(len(str(item[2])) for item in source_list),
        len(num_nodes_head)) + 1
    num_tensors_column_width = max(
        max(len(str(item[3])) for item in source_list),
        len(num_tensors_head)) + 1

    head = RL(path_head + " " * (path_column_width - len(path_head)), color)
    head += RL(num_nodes_head + " " * (
        num_nodes_column_width - len(num_nodes_head)), color)
    head += RL(num_tensors_head + " " * (
        num_tensors_column_width - len(num_tensors_head)), color)
    head += RL(num_dumps_head, color)

    lines.append(head)

    for (file_path, _, num_nodes, num_tensors, num_dumps,
         first_line_num) in source_list:
      path_attributes = [color]
      if source_utils.is_extension_uncompiled_python_source(file_path):
        # Uncompiled Python sources can be displayed: link the path to the
        # print_source (ps) command, scrolled to the file's first line.
        path_attributes.append(
            debugger_cli_common.MenuItem(None, "ps %s -b %d" %
                                         (file_path, first_line_num)))

      line = RL(file_path, path_attributes)
      line += " " * (path_column_width - len(line))
      line += RL(
          str(num_nodes) + " " * (num_nodes_column_width - len(str(num_nodes))),
          color)
      line += RL(
          str(num_tensors) + " " *
          (num_tensors_column_width - len(str(num_tensors))), color)
      line += RL(str(num_dumps), color)
      lines.append(line)
    lines.append(RL())

    return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)

  def list_source(self, args, screen_info=None):
    """List Python source files that constructed nodes and tensors."""
    del screen_info  # Unused.

    parsed = self._arg_parsers["list_source"].parse_args(args)
    source_list = source_utils.list_source_files_against_dump(
        self._debug_dump,
        path_regex_allowlist=parsed.path_filter,
        node_name_regex_allowlist=parsed.node_name_filter)

    top_lines = [
        RL("List of source files that created nodes in this run", "bold")]
    if parsed.path_filter:
      top_lines.append(
          RL("File path regex filter: \"%s\"" % parsed.path_filter))
    if parsed.node_name_filter:
      top_lines.append(
          RL("Node name regex filter: \"%s\"" % parsed.node_name_filter))
    top_lines.append(RL())
    output = debugger_cli_common.rich_text_lines_from_rich_line_list(top_lines)
    if not source_list:
      output.append("[No source file information.]")
      return output

    # Two tables: non-TF-library files first, then TF-library files
    # (item[1] is the is_tf_library flag of each tuple).
    output.extend(self._make_source_table(
        [item for item in source_list if not item[1]], False))
    output.extend(self._make_source_table(
        [item for item in source_list if item[1]], True))
    _add_main_menu(output, node_name=None)
    return output

  def _list_inputs_or_outputs(self,
                              recursive,
                              node_name,
                              depth,
                              control,
                              op_type,
                              do_outputs=False):
    """Helper function used by list_inputs and list_outputs.

    Format a list of lines to display the inputs or output recipients of a
    given node.

    Args:
      recursive: Whether the listing is to be done recursively, as a boolean.
      node_name: The name of the node in question, as a str.
      depth: Maximum recursion depth, applies only if recursive == True, as an
        int.
      control: Whether control inputs or control recipients are included, as a
        boolean.
      op_type: Whether the op types of the nodes are to be included, as a
        boolean.
      do_outputs: Whether recipients, instead of input nodes are to be
        listed, as a boolean.

    Returns:
      Input or recipient tree formatted as a RichTextLines object.
    """

    if do_outputs:
      tracker = self._debug_dump.node_recipients
      type_str = "Recipients of"
      short_type_str = "recipients"
    else:
      tracker = self._debug_dump.node_inputs
      type_str = "Inputs to"
      short_type_str = "inputs"

    lines = []
    font_attr_segs = {}

    # Check if this is a tensor name, instead of a node name.
    node_name, _ = debug_graphs.parse_node_or_tensor_name(node_name)

    # Check if node exists.
    if not self._debug_dump.node_exists(node_name):
      return cli_shared.error(
          "There is no node named \"%s\" in the partition graphs" % node_name)

    if recursive:
      max_depth = depth
    else:
      max_depth = 1

    if control:
      include_ctrls_str = ", control %s included" % short_type_str
    else:
      include_ctrls_str = ""

    line = "%s node \"%s\"" % (type_str, node_name)
    # Bold the node name, which sits between the two double quotes at the
    # end of the line.
    font_attr_segs[0] = [(len(line) - 1 - len(node_name), len(line) - 1, "bold")
                        ]
    lines.append(line + " (Depth limit = %d%s):" % (max_depth, include_ctrls_str
                                                   ))

    command_template = "lo -c -r %s" if do_outputs else "li -c -r %s"
    self._dfs_from_node(
        lines,
        font_attr_segs,
        node_name,
        tracker,
        max_depth,
        1, [],
        control,
        op_type,
        command_template=command_template)

    # Include legend.
    lines.append("")
    lines.append("Legend:")
    lines.append("  (d): recursion depth = d.")

    if control:
      lines.append("  (Ctrl): Control input.")
    if op_type:
      lines.append("  [Op]: Input node has op type Op.")

    # TODO(cais): Consider appending ":0" at the end of 1st outputs of nodes.

    return debugger_cli_common.RichTextLines(
        lines, font_attr_segs=font_attr_segs)

  def _dfs_from_node(self,
                     lines,
                     attr_segs,
                     node_name,
                     tracker,
                     max_depth,
                     depth,
                     unfinished,
                     include_control=False,
                     show_op_type=False,
                     command_template=None):
    """Perform depth-first search (DFS) traversal of a node's input tree.

    It recursively tracks the inputs (or output recipients) of the node called
    node_name, and append these inputs (or output recipients) to a list of text
    lines (lines) with proper indentation that reflects the recursion depth,
    together with some formatting attributes (to attr_segs). The formatting
    attributes can include command shortcuts, for example.

    Args:
      lines: Text lines to append to, as a list of str.
      attr_segs: (dict) Attribute segments dictionary to append to.
      node_name: Name of the node, as a str. This arg is updated during the
        recursion.
      tracker: A callable that takes one str as the node name input and
        returns a list of str as the inputs/outputs.
        This makes it this function general enough to be used with both
        node-input and node-output tracking.
      max_depth: Maximum recursion depth, as an int.
      depth: Current recursion depth. This arg is updated during the
        recursion.
      unfinished: A stack of unfinished recursion depths, as a list of int.
      include_control: Whether control dependencies are to be included as
        inputs (and marked as such).
      show_op_type: Whether op type of the input nodes are to be displayed
        alongside the nodes' names.
      command_template: (str) Template for command shortcut of the node names.
    """

    # Make a shallow copy of the list because it may be extended later.
1405 all_inputs = self._exclude_denylisted_ops( 1406 copy.copy(tracker(node_name, is_control=False))) 1407 is_ctrl = [False] * len(all_inputs) 1408 if include_control: 1409 # Sort control inputs or recipients in alphabetical order of the node 1410 # names. 1411 ctrl_inputs = self._exclude_denylisted_ops( 1412 sorted(tracker(node_name, is_control=True))) 1413 all_inputs.extend(ctrl_inputs) 1414 is_ctrl.extend([True] * len(ctrl_inputs)) 1415 1416 if not all_inputs: 1417 if depth == 1: 1418 lines.append(" [None]") 1419 1420 return 1421 1422 unfinished.append(depth) 1423 1424 # Create depth-dependent hanging indent for the line. 1425 hang = "" 1426 for k in xrange(depth): 1427 if k < depth - 1: 1428 if k + 1 in unfinished: 1429 hang += HANG_UNFINISHED 1430 else: 1431 hang += HANG_FINISHED 1432 else: 1433 hang += HANG_SUFFIX 1434 1435 if all_inputs and depth > max_depth: 1436 lines.append(hang + ELLIPSIS) 1437 unfinished.pop() 1438 return 1439 1440 hang += DEPTH_TEMPLATE % depth 1441 1442 for i, inp in enumerate(all_inputs): 1443 op_type = self._debug_dump.node_op_type(debug_graphs.get_node_name(inp)) 1444 if op_type in self._GRAPH_STRUCT_OP_TYPE_DENYLIST: 1445 continue 1446 1447 if is_ctrl[i]: 1448 ctrl_str = CTRL_LABEL 1449 else: 1450 ctrl_str = "" 1451 1452 op_type_str = "" 1453 if show_op_type: 1454 op_type_str = OP_TYPE_TEMPLATE % op_type 1455 1456 if i == len(all_inputs) - 1: 1457 unfinished.pop() 1458 1459 line = hang + ctrl_str + op_type_str + inp 1460 lines.append(line) 1461 if command_template: 1462 attr_segs[len(lines) - 1] = [( 1463 len(line) - len(inp), len(line), 1464 debugger_cli_common.MenuItem(None, command_template % inp))] 1465 1466 # Recursive call. 1467 # The input's/output's name can be a tensor name, in the case of node 1468 # with >1 output slots. 
1469 inp_node_name, _ = debug_graphs.parse_node_or_tensor_name(inp) 1470 self._dfs_from_node( 1471 lines, 1472 attr_segs, 1473 inp_node_name, 1474 tracker, 1475 max_depth, 1476 depth + 1, 1477 unfinished, 1478 include_control=include_control, 1479 show_op_type=show_op_type, 1480 command_template=command_template) 1481 1482 def _format_neighbors(self, neighbor_type, non_ctrls, ctrls): 1483 """List neighbors (inputs or recipients) of a node. 1484 1485 Args: 1486 neighbor_type: ("input" | "recipient") 1487 non_ctrls: Non-control neighbor node names, as a list of str. 1488 ctrls: Control neighbor node names, as a list of str. 1489 1490 Returns: 1491 A RichTextLines object. 1492 """ 1493 1494 # TODO(cais): Return RichTextLines instead, to allow annotation of node 1495 # names. 1496 lines = [] 1497 font_attr_segs = {} 1498 1499 lines.append("") 1500 lines.append(" %d %s(s) + %d control %s(s):" % 1501 (len(non_ctrls), neighbor_type, len(ctrls), neighbor_type)) 1502 lines.append(" %d %s(s):" % (len(non_ctrls), neighbor_type)) 1503 for non_ctrl in non_ctrls: 1504 line = " [%s] %s" % (self._debug_dump.node_op_type(non_ctrl), 1505 non_ctrl) 1506 lines.append(line) 1507 font_attr_segs[len(lines) - 1] = [( 1508 len(line) - len(non_ctrl), len(line), 1509 debugger_cli_common.MenuItem(None, "ni -a -d -t %s" % non_ctrl))] 1510 1511 if ctrls: 1512 lines.append("") 1513 lines.append(" %d control %s(s):" % (len(ctrls), neighbor_type)) 1514 for ctrl in ctrls: 1515 line = " [%s] %s" % (self._debug_dump.node_op_type(ctrl), ctrl) 1516 lines.append(line) 1517 font_attr_segs[len(lines) - 1] = [( 1518 len(line) - len(ctrl), len(line), 1519 debugger_cli_common.MenuItem(None, "ni -a -d -t %s" % ctrl))] 1520 1521 return debugger_cli_common.RichTextLines( 1522 lines, font_attr_segs=font_attr_segs) 1523 1524 def _list_node_attributes(self, node_name): 1525 """List neighbors (inputs or recipients) of a node. 1526 1527 Args: 1528 node_name: Name of the node of which the attributes are to be listed. 
1529 1530 Returns: 1531 A RichTextLines object. 1532 """ 1533 1534 lines = [] 1535 lines.append("") 1536 lines.append("Node attributes:") 1537 1538 attrs = self._debug_dump.node_attributes(node_name) 1539 for attr_key in attrs: 1540 lines.append(" %s:" % attr_key) 1541 attr_val_str = repr(attrs[attr_key]).strip().replace("\n", " ") 1542 lines.append(" %s" % attr_val_str) 1543 lines.append("") 1544 1545 return debugger_cli_common.RichTextLines(lines) 1546 1547 def _list_node_dumps(self, node_name): 1548 """List dumped tensor data from a node. 1549 1550 Args: 1551 node_name: Name of the node of which the attributes are to be listed. 1552 1553 Returns: 1554 A RichTextLines object. 1555 """ 1556 1557 lines = [] 1558 font_attr_segs = {} 1559 1560 watch_keys = self._debug_dump.debug_watch_keys(node_name) 1561 1562 dump_count = 0 1563 for watch_key in watch_keys: 1564 debug_tensor_data = self._debug_dump.watch_key_to_data(watch_key) 1565 for datum in debug_tensor_data: 1566 line = " Slot %d @ %s @ %.3f ms" % ( 1567 datum.output_slot, datum.debug_op, 1568 (datum.timestamp - self._debug_dump.t0) / 1000.0) 1569 lines.append(line) 1570 command = "pt %s:%d -n %d" % (node_name, datum.output_slot, dump_count) 1571 font_attr_segs[len(lines) - 1] = [( 1572 2, len(line), debugger_cli_common.MenuItem(None, command))] 1573 dump_count += 1 1574 1575 output = debugger_cli_common.RichTextLines( 1576 lines, font_attr_segs=font_attr_segs) 1577 output_with_header = debugger_cli_common.RichTextLines( 1578 ["%d dumped tensor(s):" % dump_count, ""]) 1579 output_with_header.extend(output) 1580 return output_with_header 1581 1582 1583def create_analyzer_ui(debug_dump, 1584 tensor_filters=None, 1585 ui_type="curses", 1586 on_ui_exit=None, 1587 config=None): 1588 """Create an instance of CursesUI based on a DebugDumpDir object. 1589 1590 Args: 1591 debug_dump: (debug_data.DebugDumpDir) The debug dump to use. 
1592 tensor_filters: (dict) A dict mapping tensor filter name (str) to tensor 1593 filter (Callable). 1594 ui_type: (str) requested UI type, e.g., "curses", "readline". 1595 on_ui_exit: (`Callable`) the callback to be called when the UI exits. 1596 config: A `cli_config.CLIConfig` object. 1597 1598 Returns: 1599 (base_ui.BaseUI) A BaseUI subtype object with a set of standard analyzer 1600 commands and tab-completions registered. 1601 """ 1602 if config is None: 1603 config = cli_config.CLIConfig() 1604 1605 analyzer = DebugAnalyzer(debug_dump, config=config) 1606 if tensor_filters: 1607 for tensor_filter_name in tensor_filters: 1608 analyzer.add_tensor_filter( 1609 tensor_filter_name, tensor_filters[tensor_filter_name]) 1610 1611 cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit, config=config) 1612 cli.register_command_handler( 1613 "list_tensors", 1614 analyzer.list_tensors, 1615 analyzer.get_help("list_tensors"), 1616 prefix_aliases=["lt"]) 1617 cli.register_command_handler( 1618 "node_info", 1619 analyzer.node_info, 1620 analyzer.get_help("node_info"), 1621 prefix_aliases=["ni"]) 1622 cli.register_command_handler( 1623 "list_inputs", 1624 analyzer.list_inputs, 1625 analyzer.get_help("list_inputs"), 1626 prefix_aliases=["li"]) 1627 cli.register_command_handler( 1628 "list_outputs", 1629 analyzer.list_outputs, 1630 analyzer.get_help("list_outputs"), 1631 prefix_aliases=["lo"]) 1632 cli.register_command_handler( 1633 "print_tensor", 1634 analyzer.print_tensor, 1635 analyzer.get_help("print_tensor"), 1636 prefix_aliases=["pt"]) 1637 cli.register_command_handler( 1638 "print_source", 1639 analyzer.print_source, 1640 analyzer.get_help("print_source"), 1641 prefix_aliases=["ps"]) 1642 cli.register_command_handler( 1643 "list_source", 1644 analyzer.list_source, 1645 analyzer.get_help("list_source"), 1646 prefix_aliases=["ls"]) 1647 cli.register_command_handler( 1648 "eval", 1649 analyzer.evaluate_expression, 1650 analyzer.get_help("eval"), 1651 
prefix_aliases=["ev"]) 1652 1653 dumped_tensor_names = [] 1654 for datum in debug_dump.dumped_tensor_data: 1655 dumped_tensor_names.append("%s:%d" % (datum.node_name, datum.output_slot)) 1656 1657 # Tab completions for command "print_tensors". 1658 cli.register_tab_comp_context(["print_tensor", "pt"], dumped_tensor_names) 1659 1660 return cli 1661