• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""CLI Backend for the Analyzer Part of the Debugger.
16
17The analyzer performs post hoc analysis of dumped intermediate tensors and
18graph structure information from debugged Session.run() calls.
19"""
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import argparse
25import copy
26import re
27
28from six.moves import xrange  # pylint: disable=redefined-builtin
29
30from tensorflow.python.debug.cli import cli_config
31from tensorflow.python.debug.cli import cli_shared
32from tensorflow.python.debug.cli import command_parser
33from tensorflow.python.debug.cli import debugger_cli_common
34from tensorflow.python.debug.cli import evaluator
35from tensorflow.python.debug.cli import ui_factory
36from tensorflow.python.debug.lib import debug_graphs
37from tensorflow.python.debug.lib import source_utils
38
# Shorthand for debugger_cli_common.RichLine, used below to build output lines
# that carry font attributes / menu items.
RL = debugger_cli_common.RichLine

# String constants for the depth-dependent hanging indent at the beginning
# of each line.
# NOTE(review): not referenced in the visible part of this file; presumably
# consumed by the recursive list_inputs/list_outputs rendering — confirm.
HANG_UNFINISHED = "|  "  # Used for unfinished recursion depths.
HANG_FINISHED = "   "  # Presumably the counterpart for finished depths.
HANG_SUFFIX = "|- "

# String constant for displaying depth and op type.
DEPTH_TEMPLATE = "(%d) "
OP_TYPE_TEMPLATE = "[%s] "

# String constants for control inputs/outputs, etc.
CTRL_LABEL = "(Ctrl) "
ELLIPSIS = "..."

# Accepted values of the `-s/--sort_by` flag of the list_tensors command
# (see DebugAnalyzer._sort_dump_data_by).
SORT_TENSORS_BY_TIMESTAMP = "timestamp"
SORT_TENSORS_BY_DUMP_SIZE = "dump_size"
SORT_TENSORS_BY_OP_TYPE = "op_type"
SORT_TENSORS_BY_TENSOR_NAME = "tensor_name"
59
60
def _add_main_menu(output,
                   node_name=None,
                   enable_list_tensors=True,
                   enable_node_info=True,
                   enable_print_tensor=True,
                   enable_list_inputs=True,
                   enable_list_outputs=True):
  """Attach the main menu to the screen output of a command.

  The menu contains, in this order: list_tensors, node_info, print_tensor,
  list_inputs, list_outputs, run_info and help. It is stored in
  `output.annotations` under `debugger_cli_common.MAIN_MENU_KEY`.

  Args:
    output: (debugger_cli_common.RichTextLines) the output object to modify.
    node_name: (str or None) name of the node involved (if any). If falsy,
      the node-specific items (node_info, print_tensor, list_inputs and
      list_outputs) are added disabled and without a command, regardless of
      the corresponding enable_* arguments.
    enable_list_tensors: (bool) whether the list_tensors item is enabled.
    enable_node_info: (bool) whether the node_info item is enabled.
    enable_print_tensor: (bool) whether the print_tensor item is enabled.
    enable_list_inputs: (bool) whether the list_inputs item is enabled.
    enable_list_outputs: (bool) whether the list_outputs item is enabled.
  """
  new_item = debugger_cli_common.MenuItem
  menu = debugger_cli_common.Menu()

  menu.append(
      new_item("list_tensors", "list_tensors", enabled=enable_list_tensors))

  # (caption, command template taking the node name, enabled flag).
  node_specific_items = (
      ("node_info", "node_info -a -d -t %s", enable_node_info),
      ("print_tensor", "print_tensor %s", enable_print_tensor),
      ("list_inputs", "list_inputs -c -r %s", enable_list_inputs),
      ("list_outputs", "list_outputs -c -r %s", enable_list_outputs),
  )
  for caption, template, enabled in node_specific_items:
    if node_name:
      item = new_item(caption, template % node_name, enabled=enabled)
    else:
      # Without a node there is nothing to act on: show a disabled
      # placeholder with no command.
      item = new_item(caption, None, enabled=False)
    menu.append(item)

  # Node-independent items whose command equals their caption.
  for caption in ("run_info", "help"):
    menu.append(new_item(caption, caption))

  output.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu
128
129
130class DebugAnalyzer(object):
  """Analyzer for debug data from dump directories.

  Provides command handlers (e.g. list_tensors, node_info, list_inputs) that
  operate on a `DebugDumpDir`, plus a registry of named tensor filters that
  list_tensors can apply via `-f/--tensor_filter`.
  """

  # Column heads of the table printed by list_tensors. Each head is annotated
  # with a list_tensors command that sorts the table by that column.
  _TIMESTAMP_COLUMN_HEAD = "t (ms)"
  _DUMP_SIZE_COLUMN_HEAD = "Size (B)"
  _OP_TYPE_COLUMN_HEAD = "Op type"
  _TENSOR_NAME_COLUMN_HEAD = "Tensor name"

  # Op types to be omitted when generating descriptions of graph structure
  # (e.g. the input and recipient lists shown by node_info).
  _GRAPH_STRUCT_OP_TYPE_DENYLIST = ("_Send", "_Recv", "_HostSend", "_HostRecv",
                                    "_Retval")
141
142  def __init__(self, debug_dump, config):
143    """DebugAnalyzer constructor.
144
145    Args:
146      debug_dump: A DebugDumpDir object.
147      config: A `cli_config.CLIConfig` object that carries user-facing
148        configurations.
149    """
150
151    self._debug_dump = debug_dump
152    self._evaluator = evaluator.ExpressionEvaluator(self._debug_dump)
153
154    # Initialize tensor filters state.
155    self._tensor_filters = {}
156
157    self._build_argument_parsers(config)
158    config.set_callback("graph_recursion_depth",
159                        self._build_argument_parsers)
160
161    # TODO(cais): Implement list_nodes.
162
163  def _build_argument_parsers(self, config):
164    """Build argument parsers for DebugAnalayzer.
165
166    Args:
167      config: A `cli_config.CLIConfig` object.
168
169    Returns:
170      A dict mapping command handler name to `ArgumentParser` instance.
171    """
172    # Argument parsers for command handlers.
173    self._arg_parsers = {}
174
175    # Parser for list_tensors.
176    ap = argparse.ArgumentParser(
177        description="List dumped intermediate tensors.",
178        usage=argparse.SUPPRESS)
179    ap.add_argument(
180        "-f",
181        "--tensor_filter",
182        dest="tensor_filter",
183        type=str,
184        default="",
185        help="List only Tensors passing the filter of the specified name")
186    ap.add_argument(
187        "-fenn",
188        "--filter_exclude_node_names",
189        dest="filter_exclude_node_names",
190        type=str,
191        default="",
192        help="When applying the tensor filter, exclude node with names "
193        "matching the regular expression. Applicable only if --tensor_filter "
194        "or -f is used.")
195    ap.add_argument(
196        "-n",
197        "--node_name_filter",
198        dest="node_name_filter",
199        type=str,
200        default="",
201        help="filter node name by regex.")
202    ap.add_argument(
203        "-t",
204        "--op_type_filter",
205        dest="op_type_filter",
206        type=str,
207        default="",
208        help="filter op type by regex.")
209    ap.add_argument(
210        "-s",
211        "--sort_by",
212        dest="sort_by",
213        type=str,
214        default=SORT_TENSORS_BY_TIMESTAMP,
215        help=("the field to sort the data by: (%s | %s | %s | %s)" %
216              (SORT_TENSORS_BY_TIMESTAMP, SORT_TENSORS_BY_DUMP_SIZE,
217               SORT_TENSORS_BY_OP_TYPE, SORT_TENSORS_BY_TENSOR_NAME)))
218    ap.add_argument(
219        "-r",
220        "--reverse",
221        dest="reverse",
222        action="store_true",
223        help="sort the data in reverse (descending) order")
224    self._arg_parsers["list_tensors"] = ap
225
226    # Parser for node_info.
227    ap = argparse.ArgumentParser(
228        description="Show information about a node.", usage=argparse.SUPPRESS)
229    ap.add_argument(
230        "node_name",
231        type=str,
232        help="Name of the node or an associated tensor, e.g., "
233        "hidden1/Wx_plus_b/MatMul, hidden1/Wx_plus_b/MatMul:0")
234    ap.add_argument(
235        "-a",
236        "--attributes",
237        dest="attributes",
238        action="store_true",
239        help="Also list attributes of the node.")
240    ap.add_argument(
241        "-d",
242        "--dumps",
243        dest="dumps",
244        action="store_true",
245        help="Also list dumps available from the node.")
246    ap.add_argument(
247        "-t",
248        "--traceback",
249        dest="traceback",
250        action="store_true",
251        help="Also include the traceback of the node's creation "
252        "(if available in Python).")
253    self._arg_parsers["node_info"] = ap
254
255    # Parser for list_inputs.
256    ap = argparse.ArgumentParser(
257        description="Show inputs to a node.", usage=argparse.SUPPRESS)
258    ap.add_argument(
259        "node_name",
260        type=str,
261        help="Name of the node or an output tensor from the node, e.g., "
262        "hidden1/Wx_plus_b/MatMul, hidden1/Wx_plus_b/MatMul:0")
263    ap.add_argument(
264        "-c", "--control", action="store_true", help="Include control inputs.")
265    ap.add_argument(
266        "-d",
267        "--depth",
268        dest="depth",
269        type=int,
270        default=config.get("graph_recursion_depth"),
271        help="Maximum depth of recursion used when showing the input tree.")
272    ap.add_argument(
273        "-r",
274        "--recursive",
275        dest="recursive",
276        action="store_true",
277        help="Show inputs to the node recursively, i.e., the input tree.")
278    ap.add_argument(
279        "-t",
280        "--op_type",
281        action="store_true",
282        help="Show op types of input nodes.")
283    self._arg_parsers["list_inputs"] = ap
284
285    # Parser for list_outputs.
286    ap = argparse.ArgumentParser(
287        description="Show the nodes that receive the outputs of given node.",
288        usage=argparse.SUPPRESS)
289    ap.add_argument(
290        "node_name",
291        type=str,
292        help="Name of the node or an output tensor from the node, e.g., "
293        "hidden1/Wx_plus_b/MatMul, hidden1/Wx_plus_b/MatMul:0")
294    ap.add_argument(
295        "-c", "--control", action="store_true", help="Include control inputs.")
296    ap.add_argument(
297        "-d",
298        "--depth",
299        dest="depth",
300        type=int,
301        default=config.get("graph_recursion_depth"),
302        help="Maximum depth of recursion used when showing the output tree.")
303    ap.add_argument(
304        "-r",
305        "--recursive",
306        dest="recursive",
307        action="store_true",
308        help="Show recipients of the node recursively, i.e., the output "
309        "tree.")
310    ap.add_argument(
311        "-t",
312        "--op_type",
313        action="store_true",
314        help="Show op types of recipient nodes.")
315    self._arg_parsers["list_outputs"] = ap
316
317    # Parser for print_tensor.
318    self._arg_parsers["print_tensor"] = (
319        command_parser.get_print_tensor_argparser(
320            "Print the value of a dumped tensor."))
321
322    # Parser for print_source.
323    ap = argparse.ArgumentParser(
324        description="Print a Python source file with overlaid debug "
325        "information, including the nodes (ops) or Tensors created at the "
326        "source lines.",
327        usage=argparse.SUPPRESS)
328    ap.add_argument(
329        "source_file_path",
330        type=str,
331        help="Path to the source file.")
332    ap.add_argument(
333        "-t",
334        "--tensors",
335        dest="tensors",
336        action="store_true",
337        help="Label lines with dumped Tensors, instead of ops.")
338    ap.add_argument(
339        "-m",
340        "--max_elements_per_line",
341        type=int,
342        default=10,
343        help="Maximum number of elements (ops or Tensors) to show per source "
344             "line.")
345    ap.add_argument(
346        "-b",
347        "--line_begin",
348        type=int,
349        default=1,
350        help="Print source beginning at line number (1-based.)")
351    self._arg_parsers["print_source"] = ap
352
353    # Parser for list_source.
354    ap = argparse.ArgumentParser(
355        description="List source files responsible for constructing nodes and "
356        "tensors present in the run().",
357        usage=argparse.SUPPRESS)
358    ap.add_argument(
359        "-p",
360        "--path_filter",
361        type=str,
362        default="",
363        help="Regular expression filter for file path.")
364    ap.add_argument(
365        "-n",
366        "--node_name_filter",
367        type=str,
368        default="",
369        help="Regular expression filter for node name.")
370    self._arg_parsers["list_source"] = ap
371
372    # Parser for eval.
373    ap = argparse.ArgumentParser(
374        description="""Evaluate an arbitrary expression. Can use tensor values
375        from the current debug dump. The debug tensor names should be enclosed
376        in pairs of backticks. Expressions with spaces should be enclosed in
377        a pair of double quotes or a pair of single quotes. By default, numpy
378        is imported as np and can be used in the expressions. E.g.,
379          1) eval np.argmax(`Softmax:0`),
380          2) eval 'np.sum(`Softmax:0`, axis=1)',
381          3) eval "np.matmul((`output/Identity:0`/`Softmax:0`).T, `Softmax:0`)".
382        """,
383        usage=argparse.SUPPRESS)
384    ap.add_argument(
385        "expression",
386        type=str,
387        help="""Expression to be evaluated.
388        1) in the simplest case, use <node_name>:<output_slot>, e.g.,
389          hidden_0/MatMul:0.
390
391        2) if the default debug op "DebugIdentity" is to be overridden, use
392          <node_name>:<output_slot>:<debug_op>, e.g.,
393          hidden_0/MatMul:0:DebugNumericSummary.
394
395        3) if the tensor of the same name exists on more than one device, use
396          <device_name>:<node_name>:<output_slot>[:<debug_op>], e.g.,
397          /job:worker/replica:0/task:0/gpu:0:hidden_0/MatMul:0
398          /job:worker/replica:0/task:2/cpu:0:hidden_0/MatMul:0:DebugNanCount.
399
400        4) if the tensor is executed multiple times in a given `Session.run`
401        call, specify the execution index with a 0-based integer enclose in a
402        pair of brackets at the end, e.g.,
403          RNN/tanh:0[0]
404          /job:worker/replica:0/task:0/gpu:0:RNN/tanh:0[0].""")
405    ap.add_argument(
406        "-a",
407        "--all",
408        dest="print_all",
409        action="store_true",
410        help="Print the tensor in its entirety, i.e., do not use ellipses "
411        "(may be slow for large results).")
412    ap.add_argument(
413        "-w",
414        "--write_path",
415        default="",
416        help="Path of the numpy file to write the evaluation result to, "
417        "using numpy.save()")
418    self._arg_parsers["eval"] = ap
419
420  def add_tensor_filter(self, filter_name, filter_callable):
421    """Add a tensor filter.
422
423    A tensor filter is a named callable of the signature:
424      filter_callable(dump_datum, tensor),
425
426    wherein dump_datum is an instance of debug_data.DebugTensorDatum carrying
427    metadata about the dumped tensor, including tensor name, timestamps, etc.
428    tensor is the value of the dumped tensor as an numpy.ndarray object.
429    The return value of the function is a bool.
430    This is the same signature as the input argument to
431    debug_data.DebugDumpDir.find().
432
433    Args:
434      filter_name: (str) name of the filter. Cannot be empty.
435      filter_callable: (callable) a filter function of the signature described
436        as above.
437
438    Raises:
439      ValueError: If filter_name is an empty str.
440      TypeError: If filter_name is not a str.
441                 Or if filter_callable is not callable.
442    """
443
444    if not isinstance(filter_name, str):
445      raise TypeError("Input argument filter_name is expected to be str, "
446                      "but is not.")
447
448    # Check that filter_name is not an empty str.
449    if not filter_name:
450      raise ValueError("Input argument filter_name cannot be empty.")
451
452    # Check that filter_callable is callable.
453    if not callable(filter_callable):
454      raise TypeError(
455          "Input argument filter_callable is expected to be callable, "
456          "but is not.")
457
458    self._tensor_filters[filter_name] = filter_callable
459
460  def get_tensor_filter(self, filter_name):
461    """Retrieve filter function by name.
462
463    Args:
464      filter_name: Name of the filter set during add_tensor_filter() call.
465
466    Returns:
467      The callable associated with the filter name.
468
469    Raises:
470      ValueError: If there is no tensor filter of the specified filter name.
471    """
472
473    if filter_name not in self._tensor_filters:
474      raise ValueError("There is no tensor filter named \"%s\"" % filter_name)
475
476    return self._tensor_filters[filter_name]
477
478  def get_help(self, handler_name):
479    return self._arg_parsers[handler_name].format_help()
480
481  def list_tensors(self, args, screen_info=None):
482    """Command handler for list_tensors.
483
484    List tensors dumped during debugged Session.run() call.
485
486    Args:
487      args: Command-line arguments, excluding the command prefix, as a list of
488        str.
489      screen_info: Optional dict input containing screen information such as
490        cols.
491
492    Returns:
493      Output text lines as a RichTextLines object.
494
495    Raises:
496      ValueError: If `--filter_exclude_node_names` is used without `-f` or
497        `--tensor_filter` being used.
498    """
499
500    # TODO(cais): Add annotations of substrings for dumped tensor names, to
501    # facilitate on-screen highlighting/selection of node names.
502    _ = screen_info
503
504    parsed = self._arg_parsers["list_tensors"].parse_args(args)
505
506    output = []
507
508    filter_strs = []
509    if parsed.op_type_filter:
510      op_type_regex = re.compile(parsed.op_type_filter)
511      filter_strs.append("Op type regex filter: \"%s\"" % parsed.op_type_filter)
512    else:
513      op_type_regex = None
514
515    if parsed.node_name_filter:
516      node_name_regex = re.compile(parsed.node_name_filter)
517      filter_strs.append("Node name regex filter: \"%s\"" %
518                         parsed.node_name_filter)
519    else:
520      node_name_regex = None
521
522    output = debugger_cli_common.RichTextLines(filter_strs)
523    output.append("")
524
525    if parsed.tensor_filter:
526      try:
527        filter_callable = self.get_tensor_filter(parsed.tensor_filter)
528      except ValueError:
529        output = cli_shared.error("There is no tensor filter named \"%s\"." %
530                                  parsed.tensor_filter)
531        _add_main_menu(output, node_name=None, enable_list_tensors=False)
532        return output
533
534      data_to_show = self._debug_dump.find(
535          filter_callable,
536          exclude_node_names=parsed.filter_exclude_node_names)
537    else:
538      if parsed.filter_exclude_node_names:
539        raise ValueError(
540            "The flag --filter_exclude_node_names is valid only when "
541            "the flag -f or --tensor_filter is used.")
542
543      data_to_show = self._debug_dump.dumped_tensor_data
544
545    # TODO(cais): Implement filter by lambda on tensor value.
546
547    max_timestamp_width, max_dump_size_width, max_op_type_width = (
548        self._measure_tensor_list_column_widths(data_to_show))
549
550    # Sort the data.
551    data_to_show = self._sort_dump_data_by(
552        data_to_show, parsed.sort_by, parsed.reverse)
553
554    output.extend(
555        self._tensor_list_column_heads(parsed, max_timestamp_width,
556                                       max_dump_size_width, max_op_type_width))
557
558    dump_count = 0
559    for dump in data_to_show:
560      if node_name_regex and not node_name_regex.match(dump.node_name):
561        continue
562
563      if op_type_regex:
564        op_type = self._debug_dump.node_op_type(dump.node_name)
565        if not op_type_regex.match(op_type):
566          continue
567
568      rel_time = (dump.timestamp - self._debug_dump.t0) / 1000.0
569      dump_size_str = cli_shared.bytes_to_readable_str(dump.dump_size_bytes)
570      dumped_tensor_name = "%s:%d" % (dump.node_name, dump.output_slot)
571      op_type = self._debug_dump.node_op_type(dump.node_name)
572
573      line = "[%.3f]" % rel_time
574      line += " " * (max_timestamp_width - len(line))
575      line += dump_size_str
576      line += " " * (max_timestamp_width + max_dump_size_width - len(line))
577      line += op_type
578      line += " " * (max_timestamp_width + max_dump_size_width +
579                     max_op_type_width - len(line))
580      line += dumped_tensor_name
581
582      output.append(
583          line,
584          font_attr_segs=[(
585              len(line) - len(dumped_tensor_name), len(line),
586              debugger_cli_common.MenuItem("", "pt %s" % dumped_tensor_name))])
587      dump_count += 1
588
589    if parsed.tensor_filter:
590      output.prepend([
591          "%d dumped tensor(s) passing filter \"%s\":" %
592          (dump_count, parsed.tensor_filter)
593      ])
594    else:
595      output.prepend(["%d dumped tensor(s):" % dump_count])
596
597    _add_main_menu(output, node_name=None, enable_list_tensors=False)
598    return output
599
600  def _measure_tensor_list_column_widths(self, data):
601    """Determine the maximum widths of the timestamp and op-type column.
602
603    This method assumes that data is sorted in the default order, i.e.,
604    by ascending timestamps.
605
606    Args:
607      data: (list of DebugTensorDaum) the data based on which the maximum
608        column widths will be determined.
609
610    Returns:
611      (int) maximum width of the timestamp column. 0 if data is empty.
612      (int) maximum width of the dump size column. 0 if data is empty.
613      (int) maximum width of the op type column. 0 if data is empty.
614    """
615
616    max_timestamp_width = 0
617    if data:
618      max_rel_time_ms = (data[-1].timestamp - self._debug_dump.t0) / 1000.0
619      max_timestamp_width = len("[%.3f] " % max_rel_time_ms) + 1
620    max_timestamp_width = max(max_timestamp_width,
621                              len(self._TIMESTAMP_COLUMN_HEAD) + 1)
622
623    max_dump_size_width = 0
624    for dump in data:
625      dump_size_str = cli_shared.bytes_to_readable_str(dump.dump_size_bytes)
626      if len(dump_size_str) + 1 > max_dump_size_width:
627        max_dump_size_width = len(dump_size_str) + 1
628    max_dump_size_width = max(max_dump_size_width,
629                              len(self._DUMP_SIZE_COLUMN_HEAD) + 1)
630
631    max_op_type_width = 0
632    for dump in data:
633      op_type = self._debug_dump.node_op_type(dump.node_name)
634      if len(op_type) + 1 > max_op_type_width:
635        max_op_type_width = len(op_type) + 1
636    max_op_type_width = max(max_op_type_width,
637                            len(self._OP_TYPE_COLUMN_HEAD) + 1)
638
639    return max_timestamp_width, max_dump_size_width, max_op_type_width
640
641  def _sort_dump_data_by(self, data, sort_by, reverse):
642    """Sort a list of DebugTensorDatum in specified order.
643
644    Args:
645      data: (list of DebugTensorDatum) the data to be sorted.
646      sort_by: The field to sort data by.
647      reverse: (bool) Whether to use reversed (descending) order.
648
649    Returns:
650      (list of DebugTensorDatum) in sorted order.
651
652    Raises:
653      ValueError: given an invalid value of sort_by.
654    """
655
656    if sort_by == SORT_TENSORS_BY_TIMESTAMP:
657      return sorted(
658          data,
659          reverse=reverse,
660          key=lambda x: x.timestamp)
661    elif sort_by == SORT_TENSORS_BY_DUMP_SIZE:
662      return sorted(data, reverse=reverse, key=lambda x: x.dump_size_bytes)
663    elif sort_by == SORT_TENSORS_BY_OP_TYPE:
664      return sorted(
665          data,
666          reverse=reverse,
667          key=lambda x: self._debug_dump.node_op_type(x.node_name))
668    elif sort_by == SORT_TENSORS_BY_TENSOR_NAME:
669      return sorted(
670          data,
671          reverse=reverse,
672          key=lambda x: "%s:%d" % (x.node_name, x.output_slot))
673    else:
674      raise ValueError("Unsupported key to sort tensors by: %s" % sort_by)
675
676  def _tensor_list_column_heads(self, parsed, max_timestamp_width,
677                                max_dump_size_width, max_op_type_width):
678    """Generate a line containing the column heads of the tensor list.
679
680    Args:
681      parsed: Parsed arguments (by argparse) of the list_tensors command.
682      max_timestamp_width: (int) maximum width of the timestamp column.
683      max_dump_size_width: (int) maximum width of the dump size column.
684      max_op_type_width: (int) maximum width of the op type column.
685
686    Returns:
687      A RichTextLines object.
688    """
689
690    base_command = "list_tensors"
691    if parsed.tensor_filter:
692      base_command += " -f %s" % parsed.tensor_filter
693    if parsed.op_type_filter:
694      base_command += " -t %s" % parsed.op_type_filter
695    if parsed.node_name_filter:
696      base_command += " -n %s" % parsed.node_name_filter
697
698    attr_segs = {0: []}
699    row = self._TIMESTAMP_COLUMN_HEAD
700    command = "%s -s %s" % (base_command, SORT_TENSORS_BY_TIMESTAMP)
701    if parsed.sort_by == SORT_TENSORS_BY_TIMESTAMP and not parsed.reverse:
702      command += " -r"
703    attr_segs[0].append(
704        (0, len(row), [debugger_cli_common.MenuItem(None, command), "bold"]))
705    row += " " * (max_timestamp_width - len(row))
706
707    prev_len = len(row)
708    row += self._DUMP_SIZE_COLUMN_HEAD
709    command = "%s -s %s" % (base_command, SORT_TENSORS_BY_DUMP_SIZE)
710    if parsed.sort_by == SORT_TENSORS_BY_DUMP_SIZE and not parsed.reverse:
711      command += " -r"
712    attr_segs[0].append((prev_len, len(row),
713                         [debugger_cli_common.MenuItem(None, command), "bold"]))
714    row += " " * (max_dump_size_width + max_timestamp_width - len(row))
715
716    prev_len = len(row)
717    row += self._OP_TYPE_COLUMN_HEAD
718    command = "%s -s %s" % (base_command, SORT_TENSORS_BY_OP_TYPE)
719    if parsed.sort_by == SORT_TENSORS_BY_OP_TYPE and not parsed.reverse:
720      command += " -r"
721    attr_segs[0].append((prev_len, len(row),
722                         [debugger_cli_common.MenuItem(None, command), "bold"]))
723    row += " " * (
724        max_op_type_width + max_dump_size_width + max_timestamp_width - len(row)
725    )
726
727    prev_len = len(row)
728    row += self._TENSOR_NAME_COLUMN_HEAD
729    command = "%s -s %s" % (base_command, SORT_TENSORS_BY_TENSOR_NAME)
730    if parsed.sort_by == SORT_TENSORS_BY_TENSOR_NAME and not parsed.reverse:
731      command += " -r"
732    attr_segs[0].append((prev_len, len(row),
733                         [debugger_cli_common.MenuItem("", command), "bold"]))
734    row += " " * (
735        max_op_type_width + max_dump_size_width + max_timestamp_width - len(row)
736    )
737
738    return debugger_cli_common.RichTextLines([row], font_attr_segs=attr_segs)
739
740  def node_info(self, args, screen_info=None):
741    """Command handler for node_info.
742
743    Query information about a given node.
744
745    Args:
746      args: Command-line arguments, excluding the command prefix, as a list of
747        str.
748      screen_info: Optional dict input containing screen information such as
749        cols.
750
751    Returns:
752      Output text lines as a RichTextLines object.
753    """
754
755    # TODO(cais): Add annotation of substrings for node names, to facilitate
756    # on-screen highlighting/selection of node names.
757    _ = screen_info
758
759    parsed = self._arg_parsers["node_info"].parse_args(args)
760
761    # Get a node name, regardless of whether the input is a node name (without
762    # output slot attached) or a tensor name (with output slot attached).
763    node_name, unused_slot = debug_graphs.parse_node_or_tensor_name(
764        parsed.node_name)
765
766    if not self._debug_dump.node_exists(node_name):
767      output = cli_shared.error(
768          "There is no node named \"%s\" in the partition graphs" % node_name)
769      _add_main_menu(
770          output,
771          node_name=None,
772          enable_list_tensors=True,
773          enable_node_info=False,
774          enable_list_inputs=False,
775          enable_list_outputs=False)
776      return output
777
778    # TODO(cais): Provide UI glossary feature to explain to users what the
779    # term "partition graph" means and how it is related to TF graph objects
780    # in Python. The information can be along the line of:
781    # "A tensorflow graph defined in Python is stripped of unused ops
782    # according to the feeds and fetches and divided into a number of
783    # partition graphs that may be distributed among multiple devices and
784    # hosts. The partition graphs are what's actually executed by the C++
785    # runtime during a run() call."
786
787    lines = ["Node %s" % node_name]
788    font_attr_segs = {
789        0: [(len(lines[-1]) - len(node_name), len(lines[-1]), "bold")]
790    }
791    lines.append("")
792    lines.append("  Op: %s" % self._debug_dump.node_op_type(node_name))
793    lines.append("  Device: %s" % self._debug_dump.node_device(node_name))
794    output = debugger_cli_common.RichTextLines(
795        lines, font_attr_segs=font_attr_segs)
796
797    # List node inputs (non-control and control).
798    inputs = self._exclude_denylisted_ops(
799        self._debug_dump.node_inputs(node_name))
800    ctrl_inputs = self._exclude_denylisted_ops(
801        self._debug_dump.node_inputs(node_name, is_control=True))
802    output.extend(self._format_neighbors("input", inputs, ctrl_inputs))
803
804    # List node output recipients (non-control and control).
805    recs = self._exclude_denylisted_ops(
806        self._debug_dump.node_recipients(node_name))
807    ctrl_recs = self._exclude_denylisted_ops(
808        self._debug_dump.node_recipients(node_name, is_control=True))
809    output.extend(self._format_neighbors("recipient", recs, ctrl_recs))
810
811    # Optional: List attributes of the node.
812    if parsed.attributes:
813      output.extend(self._list_node_attributes(node_name))
814
815    # Optional: List dumps available from the node.
816    if parsed.dumps:
817      output.extend(self._list_node_dumps(node_name))
818
819    if parsed.traceback:
820      output.extend(self._render_node_traceback(node_name))
821
822    _add_main_menu(output, node_name=node_name, enable_node_info=False)
823    return output
824
825  def _exclude_denylisted_ops(self, node_names):
826    """Exclude all nodes whose op types are in _GRAPH_STRUCT_OP_TYPE_DENYLIST.
827
828    Args:
829      node_names: An iterable of node or graph element names.
830
831    Returns:
832      A list of node names that are not denylisted.
833    """
834    return [
835        node_name for node_name in node_names
836        if self._debug_dump.node_op_type(debug_graphs.get_node_name(node_name))
837        not in self._GRAPH_STRUCT_OP_TYPE_DENYLIST
838    ]
839
840  def _render_node_traceback(self, node_name):
841    """Render traceback of a node's creation in Python, if available.
842
843    Args:
844      node_name: (str) name of the node.
845
846    Returns:
847      A RichTextLines object containing the stack trace of the node's
848      construction.
849    """
850
851    lines = [RL(""), RL(""), RL("Traceback of node construction:", "bold")]
852
853    try:
854      node_stack = self._debug_dump.node_traceback(node_name)
855      for depth, (file_path, line, function_name, text) in enumerate(
856          node_stack):
857        lines.append("%d: %s" % (depth, file_path))
858
859        attribute = debugger_cli_common.MenuItem(
860            "", "ps %s -b %d" % (file_path, line)) if text else None
861        line_number_line = RL("  ")
862        line_number_line += RL("Line:     %d" % line, attribute)
863        lines.append(line_number_line)
864
865        lines.append("  Function: %s" % function_name)
866        lines.append("  Text:     " + (("\"%s\"" % text) if text else "None"))
867        lines.append("")
868    except KeyError:
869      lines.append("(Node unavailable in the loaded Python graph)")
870    except LookupError:
871      lines.append("(Unavailable because no Python graph has been loaded)")
872
873    return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
874
875  def list_inputs(self, args, screen_info=None):
876    """Command handler for inputs.
877
878    Show inputs to a given node.
879
880    Args:
881      args: Command-line arguments, excluding the command prefix, as a list of
882        str.
883      screen_info: Optional dict input containing screen information such as
884        cols.
885
886    Returns:
887      Output text lines as a RichTextLines object.
888    """
889
890    # Screen info not currently used by this handler. Include this line to
891    # mute pylint.
892    _ = screen_info
893    # TODO(cais): Use screen info to format the output lines more prettily,
894    # e.g., hanging indent of long node names.
895
896    parsed = self._arg_parsers["list_inputs"].parse_args(args)
897
898    output = self._list_inputs_or_outputs(
899        parsed.recursive,
900        parsed.node_name,
901        parsed.depth,
902        parsed.control,
903        parsed.op_type,
904        do_outputs=False)
905
906    node_name = debug_graphs.get_node_name(parsed.node_name)
907    _add_main_menu(output, node_name=node_name, enable_list_inputs=False)
908
909    return output
910
911  def print_tensor(self, args, screen_info=None):
912    """Command handler for print_tensor.
913
914    Print value of a given dumped tensor.
915
916    Args:
917      args: Command-line arguments, excluding the command prefix, as a list of
918        str.
919      screen_info: Optional dict input containing screen information such as
920        cols.
921
922    Returns:
923      Output text lines as a RichTextLines object.
924    """
925
926    parsed = self._arg_parsers["print_tensor"].parse_args(args)
927
928    np_printoptions = cli_shared.numpy_printoptions_from_screen_info(
929        screen_info)
930
931    # Determine if any range-highlighting is required.
932    highlight_options = cli_shared.parse_ranges_highlight(parsed.ranges)
933
934    tensor_name, tensor_slicing = (
935        command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))
936
937    node_name, output_slot = debug_graphs.parse_node_or_tensor_name(tensor_name)
938    if (self._debug_dump.loaded_partition_graphs() and
939        not self._debug_dump.node_exists(node_name)):
940      output = cli_shared.error(
941          "Node \"%s\" does not exist in partition graphs" % node_name)
942      _add_main_menu(
943          output,
944          node_name=None,
945          enable_list_tensors=True,
946          enable_print_tensor=False)
947      return output
948
949    watch_keys = self._debug_dump.debug_watch_keys(node_name)
950    if output_slot is None:
951      output_slots = set()
952      for watch_key in watch_keys:
953        output_slots.add(int(watch_key.split(":")[1]))
954
955      if len(output_slots) == 1:
956        # There is only one dumped tensor from this node, so there is no
957        # ambiguity. Proceed to show the only dumped tensor.
958        output_slot = list(output_slots)[0]
959      else:
960        # There are more than one dumped tensors from this node. Indicate as
961        # such.
962        # TODO(cais): Provide an output screen with command links for
963        # convenience.
964        lines = [
965            "Node \"%s\" generated debug dumps from %s output slots:" %
966            (node_name, len(output_slots)),
967            "Please specify the output slot: %s:x." % node_name
968        ]
969        output = debugger_cli_common.RichTextLines(lines)
970        _add_main_menu(
971            output,
972            node_name=node_name,
973            enable_list_tensors=True,
974            enable_print_tensor=False)
975        return output
976
977    # Find debug dump data that match the tensor name (node name + output
978    # slot).
979    matching_data = []
980    for watch_key in watch_keys:
981      debug_tensor_data = self._debug_dump.watch_key_to_data(watch_key)
982      for datum in debug_tensor_data:
983        if datum.output_slot == output_slot:
984          matching_data.append(datum)
985
986    if not matching_data:
987      # No dump for this tensor.
988      output = cli_shared.error("Tensor \"%s\" did not generate any dumps." %
989                                parsed.tensor_name)
990    elif len(matching_data) == 1:
991      # There is only one dump for this tensor.
992      if parsed.number <= 0:
993        output = cli_shared.format_tensor(
994            matching_data[0].get_tensor(),
995            matching_data[0].watch_key,
996            np_printoptions,
997            print_all=parsed.print_all,
998            tensor_slicing=tensor_slicing,
999            highlight_options=highlight_options,
1000            include_numeric_summary=parsed.numeric_summary,
1001            write_path=parsed.write_path)
1002      else:
1003        output = cli_shared.error(
1004            "Invalid number (%d) for tensor %s, which generated one dump." %
1005            (parsed.number, parsed.tensor_name))
1006
1007      _add_main_menu(output, node_name=node_name, enable_print_tensor=False)
1008    else:
1009      # There are more than one dumps for this tensor.
1010      if parsed.number < 0:
1011        lines = [
1012            "Tensor \"%s\" generated %d dumps:" % (parsed.tensor_name,
1013                                                   len(matching_data))
1014        ]
1015        font_attr_segs = {}
1016
1017        for i, datum in enumerate(matching_data):
1018          rel_time = (datum.timestamp - self._debug_dump.t0) / 1000.0
1019          lines.append("#%d [%.3f ms] %s" % (i, rel_time, datum.watch_key))
1020          command = "print_tensor %s -n %d" % (parsed.tensor_name, i)
1021          font_attr_segs[len(lines) - 1] = [(
1022              len(lines[-1]) - len(datum.watch_key), len(lines[-1]),
1023              debugger_cli_common.MenuItem(None, command))]
1024
1025        lines.append("")
1026        lines.append(
1027            "You can use the -n (--number) flag to specify which dump to "
1028            "print.")
1029        lines.append("For example:")
1030        lines.append("  print_tensor %s -n 0" % parsed.tensor_name)
1031
1032        output = debugger_cli_common.RichTextLines(
1033            lines, font_attr_segs=font_attr_segs)
1034      elif parsed.number >= len(matching_data):
1035        output = cli_shared.error(
1036            "Specified number (%d) exceeds the number of available dumps "
1037            "(%d) for tensor %s" %
1038            (parsed.number, len(matching_data), parsed.tensor_name))
1039      else:
1040        output = cli_shared.format_tensor(
1041            matching_data[parsed.number].get_tensor(),
1042            matching_data[parsed.number].watch_key + " (dump #%d)" %
1043            parsed.number,
1044            np_printoptions,
1045            print_all=parsed.print_all,
1046            tensor_slicing=tensor_slicing,
1047            highlight_options=highlight_options,
1048            write_path=parsed.write_path)
1049      _add_main_menu(output, node_name=node_name, enable_print_tensor=False)
1050
1051    return output
1052
1053  def list_outputs(self, args, screen_info=None):
1054    """Command handler for inputs.
1055
1056    Show inputs to a given node.
1057
1058    Args:
1059      args: Command-line arguments, excluding the command prefix, as a list of
1060        str.
1061      screen_info: Optional dict input containing screen information such as
1062        cols.
1063
1064    Returns:
1065      Output text lines as a RichTextLines object.
1066    """
1067
1068    # Screen info not currently used by this handler. Include this line to
1069    # mute pylint.
1070    _ = screen_info
1071    # TODO(cais): Use screen info to format the output lines more prettily,
1072    # e.g., hanging indent of long node names.
1073
1074    parsed = self._arg_parsers["list_outputs"].parse_args(args)
1075
1076    output = self._list_inputs_or_outputs(
1077        parsed.recursive,
1078        parsed.node_name,
1079        parsed.depth,
1080        parsed.control,
1081        parsed.op_type,
1082        do_outputs=True)
1083
1084    node_name = debug_graphs.get_node_name(parsed.node_name)
1085    _add_main_menu(output, node_name=node_name, enable_list_outputs=False)
1086
1087    return output
1088
1089  def evaluate_expression(self, args, screen_info=None):
1090    parsed = self._arg_parsers["eval"].parse_args(args)
1091
1092    eval_res = self._evaluator.evaluate(parsed.expression)
1093
1094    np_printoptions = cli_shared.numpy_printoptions_from_screen_info(
1095        screen_info)
1096    return cli_shared.format_tensor(
1097        eval_res,
1098        "from eval of expression '%s'" % parsed.expression,
1099        np_printoptions,
1100        print_all=parsed.print_all,
1101        include_numeric_summary=True,
1102        write_path=parsed.write_path)
1103
1104  def _reconstruct_print_source_command(self,
1105                                        parsed,
1106                                        line_begin,
1107                                        max_elements_per_line_increase=0):
1108    return "ps %s %s -b %d -m %d" % (
1109        parsed.source_file_path, "-t" if parsed.tensors else "", line_begin,
1110        parsed.max_elements_per_line + max_elements_per_line_increase)
1111
1112  def print_source(self, args, screen_info=None):
1113    """Print the content of a source file."""
1114    del screen_info  # Unused.
1115
1116    parsed = self._arg_parsers["print_source"].parse_args(args)
1117
1118    source_annotation = source_utils.annotate_source(
1119        self._debug_dump,
1120        parsed.source_file_path,
1121        do_dumped_tensors=parsed.tensors)
1122
1123    source_lines, line_num_width = source_utils.load_source(
1124        parsed.source_file_path)
1125
1126    labeled_source_lines = []
1127    actual_initial_scroll_target = 0
1128    for i, line in enumerate(source_lines):
1129      annotated_line = RL("L%d" % (i + 1), cli_shared.COLOR_YELLOW)
1130      annotated_line += " " * (line_num_width - len(annotated_line))
1131      annotated_line += line
1132      labeled_source_lines.append(annotated_line)
1133
1134      if i + 1 == parsed.line_begin:
1135        actual_initial_scroll_target = len(labeled_source_lines) - 1
1136
1137      if i + 1 in source_annotation:
1138        sorted_elements = sorted(source_annotation[i + 1])
1139        for k, element in enumerate(sorted_elements):
1140          if k >= parsed.max_elements_per_line:
1141            omitted_info_line = RL("    (... Omitted %d of %d %s ...) " % (
1142                len(sorted_elements) - parsed.max_elements_per_line,
1143                len(sorted_elements),
1144                "tensor(s)" if parsed.tensors else "op(s)"))
1145            omitted_info_line += RL(
1146                "+5",
1147                debugger_cli_common.MenuItem(
1148                    None,
1149                    self._reconstruct_print_source_command(
1150                        parsed, i + 1, max_elements_per_line_increase=5)))
1151            labeled_source_lines.append(omitted_info_line)
1152            break
1153
1154          label = RL(" " * 4)
1155          if self._debug_dump.debug_watch_keys(
1156              debug_graphs.get_node_name(element)):
1157            attribute = debugger_cli_common.MenuItem("", "pt %s" % element)
1158          else:
1159            attribute = cli_shared.COLOR_BLUE
1160
1161          label += RL(element, attribute)
1162          labeled_source_lines.append(label)
1163
1164    output = debugger_cli_common.rich_text_lines_from_rich_line_list(
1165        labeled_source_lines,
1166        annotations={debugger_cli_common.INIT_SCROLL_POS_KEY:
1167                     actual_initial_scroll_target})
1168    _add_main_menu(output, node_name=None)
1169    return output
1170
1171  def _make_source_table(self, source_list, is_tf_py_library):
1172    """Make a table summarizing the source files that create nodes and tensors.
1173
1174    Args:
1175      source_list: List of source files and related information as a list of
1176        tuples (file_path, is_tf_library, num_nodes, num_tensors, num_dumps,
1177        first_line).
1178      is_tf_py_library: (`bool`) whether this table is for files that belong
1179        to the TensorFlow Python library.
1180
1181    Returns:
1182      The table as a `debugger_cli_common.RichTextLines` object.
1183    """
1184    path_head = "Source file path"
1185    num_nodes_head = "#(nodes)"
1186    num_tensors_head = "#(tensors)"
1187    num_dumps_head = "#(tensor dumps)"
1188
1189    if is_tf_py_library:
1190      # Use color to mark files that are guessed to belong to TensorFlow Python
1191      # library.
1192      color = cli_shared.COLOR_GRAY
1193      lines = [RL("TensorFlow Python library file(s):", color)]
1194    else:
1195      color = cli_shared.COLOR_WHITE
1196      lines = [RL("File(s) outside TensorFlow Python library:", color)]
1197
1198    if not source_list:
1199      lines.append(RL("[No files.]"))
1200      lines.append(RL())
1201      return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
1202
1203    path_column_width = max(
1204        max(len(item[0]) for item in source_list), len(path_head)) + 1
1205    num_nodes_column_width = max(
1206        max(len(str(item[2])) for item in source_list),
1207        len(num_nodes_head)) + 1
1208    num_tensors_column_width = max(
1209        max(len(str(item[3])) for item in source_list),
1210        len(num_tensors_head)) + 1
1211
1212    head = RL(path_head + " " * (path_column_width - len(path_head)), color)
1213    head += RL(num_nodes_head + " " * (
1214        num_nodes_column_width - len(num_nodes_head)), color)
1215    head += RL(num_tensors_head + " " * (
1216        num_tensors_column_width - len(num_tensors_head)), color)
1217    head += RL(num_dumps_head, color)
1218
1219    lines.append(head)
1220
1221    for (file_path, _, num_nodes, num_tensors, num_dumps,
1222         first_line_num) in source_list:
1223      path_attributes = [color]
1224      if source_utils.is_extension_uncompiled_python_source(file_path):
1225        path_attributes.append(
1226            debugger_cli_common.MenuItem(None, "ps %s -b %d" %
1227                                         (file_path, first_line_num)))
1228
1229      line = RL(file_path, path_attributes)
1230      line += " " * (path_column_width - len(line))
1231      line += RL(
1232          str(num_nodes) + " " * (num_nodes_column_width - len(str(num_nodes))),
1233          color)
1234      line += RL(
1235          str(num_tensors) + " " *
1236          (num_tensors_column_width - len(str(num_tensors))), color)
1237      line += RL(str(num_dumps), color)
1238      lines.append(line)
1239    lines.append(RL())
1240
1241    return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
1242
1243  def list_source(self, args, screen_info=None):
1244    """List Python source files that constructed nodes and tensors."""
1245    del screen_info  # Unused.
1246
1247    parsed = self._arg_parsers["list_source"].parse_args(args)
1248    source_list = source_utils.list_source_files_against_dump(
1249        self._debug_dump,
1250        path_regex_allowlist=parsed.path_filter,
1251        node_name_regex_allowlist=parsed.node_name_filter)
1252
1253    top_lines = [
1254        RL("List of source files that created nodes in this run", "bold")]
1255    if parsed.path_filter:
1256      top_lines.append(
1257          RL("File path regex filter: \"%s\"" % parsed.path_filter))
1258    if parsed.node_name_filter:
1259      top_lines.append(
1260          RL("Node name regex filter: \"%s\"" % parsed.node_name_filter))
1261    top_lines.append(RL())
1262    output = debugger_cli_common.rich_text_lines_from_rich_line_list(top_lines)
1263    if not source_list:
1264      output.append("[No source file information.]")
1265      return output
1266
1267    output.extend(self._make_source_table(
1268        [item for item in source_list if not item[1]], False))
1269    output.extend(self._make_source_table(
1270        [item for item in source_list if item[1]], True))
1271    _add_main_menu(output, node_name=None)
1272    return output
1273
1274  def _list_inputs_or_outputs(self,
1275                              recursive,
1276                              node_name,
1277                              depth,
1278                              control,
1279                              op_type,
1280                              do_outputs=False):
1281    """Helper function used by list_inputs and list_outputs.
1282
1283    Format a list of lines to display the inputs or output recipients of a
1284    given node.
1285
1286    Args:
1287      recursive: Whether the listing is to be done recursively, as a boolean.
1288      node_name: The name of the node in question, as a str.
1289      depth: Maximum recursion depth, applies only if recursive == True, as an
1290        int.
1291      control: Whether control inputs or control recipients are included, as a
1292        boolean.
1293      op_type: Whether the op types of the nodes are to be included, as a
1294        boolean.
1295      do_outputs: Whether recipients, instead of input nodes are to be
1296        listed, as a boolean.
1297
1298    Returns:
1299      Input or recipient tree formatted as a RichTextLines object.
1300    """
1301
1302    if do_outputs:
1303      tracker = self._debug_dump.node_recipients
1304      type_str = "Recipients of"
1305      short_type_str = "recipients"
1306    else:
1307      tracker = self._debug_dump.node_inputs
1308      type_str = "Inputs to"
1309      short_type_str = "inputs"
1310
1311    lines = []
1312    font_attr_segs = {}
1313
1314    # Check if this is a tensor name, instead of a node name.
1315    node_name, _ = debug_graphs.parse_node_or_tensor_name(node_name)
1316
1317    # Check if node exists.
1318    if not self._debug_dump.node_exists(node_name):
1319      return cli_shared.error(
1320          "There is no node named \"%s\" in the partition graphs" % node_name)
1321
1322    if recursive:
1323      max_depth = depth
1324    else:
1325      max_depth = 1
1326
1327    if control:
1328      include_ctrls_str = ", control %s included" % short_type_str
1329    else:
1330      include_ctrls_str = ""
1331
1332    line = "%s node \"%s\"" % (type_str, node_name)
1333    font_attr_segs[0] = [(len(line) - 1 - len(node_name), len(line) - 1, "bold")
1334                        ]
1335    lines.append(line + " (Depth limit = %d%s):" % (max_depth, include_ctrls_str
1336                                                   ))
1337
1338    command_template = "lo -c -r %s" if do_outputs else "li -c -r %s"
1339    self._dfs_from_node(
1340        lines,
1341        font_attr_segs,
1342        node_name,
1343        tracker,
1344        max_depth,
1345        1, [],
1346        control,
1347        op_type,
1348        command_template=command_template)
1349
1350    # Include legend.
1351    lines.append("")
1352    lines.append("Legend:")
1353    lines.append("  (d): recursion depth = d.")
1354
1355    if control:
1356      lines.append("  (Ctrl): Control input.")
1357    if op_type:
1358      lines.append("  [Op]: Input node has op type Op.")
1359
1360    # TODO(cais): Consider appending ":0" at the end of 1st outputs of nodes.
1361
1362    return debugger_cli_common.RichTextLines(
1363        lines, font_attr_segs=font_attr_segs)
1364
1365  def _dfs_from_node(self,
1366                     lines,
1367                     attr_segs,
1368                     node_name,
1369                     tracker,
1370                     max_depth,
1371                     depth,
1372                     unfinished,
1373                     include_control=False,
1374                     show_op_type=False,
1375                     command_template=None):
1376    """Perform depth-first search (DFS) traversal of a node's input tree.
1377
1378    It recursively tracks the inputs (or output recipients) of the node called
1379    node_name, and append these inputs (or output recipients) to a list of text
1380    lines (lines) with proper indentation that reflects the recursion depth,
1381    together with some formatting attributes (to attr_segs). The formatting
1382    attributes can include command shortcuts, for example.
1383
1384    Args:
1385      lines: Text lines to append to, as a list of str.
1386      attr_segs: (dict) Attribute segments dictionary to append to.
1387      node_name: Name of the node, as a str. This arg is updated during the
1388        recursion.
1389      tracker: A callable that takes one str as the node name input and
1390        returns a list of str as the inputs/outputs.
1391        This makes it this function general enough to be used with both
1392        node-input and node-output tracking.
1393      max_depth: Maximum recursion depth, as an int.
1394      depth: Current recursion depth. This arg is updated during the
1395        recursion.
1396      unfinished: A stack of unfinished recursion depths, as a list of int.
1397      include_control: Whether control dependencies are to be included as
1398        inputs (and marked as such).
1399      show_op_type: Whether op type of the input nodes are to be displayed
1400        alongside the nodes' names.
1401      command_template: (str) Template for command shortcut of the node names.
1402    """
1403
1404    # Make a shallow copy of the list because it may be extended later.
1405    all_inputs = self._exclude_denylisted_ops(
1406        copy.copy(tracker(node_name, is_control=False)))
1407    is_ctrl = [False] * len(all_inputs)
1408    if include_control:
1409      # Sort control inputs or recipients in alphabetical order of the node
1410      # names.
1411      ctrl_inputs = self._exclude_denylisted_ops(
1412          sorted(tracker(node_name, is_control=True)))
1413      all_inputs.extend(ctrl_inputs)
1414      is_ctrl.extend([True] * len(ctrl_inputs))
1415
1416    if not all_inputs:
1417      if depth == 1:
1418        lines.append("  [None]")
1419
1420      return
1421
1422    unfinished.append(depth)
1423
1424    # Create depth-dependent hanging indent for the line.
1425    hang = ""
1426    for k in xrange(depth):
1427      if k < depth - 1:
1428        if k + 1 in unfinished:
1429          hang += HANG_UNFINISHED
1430        else:
1431          hang += HANG_FINISHED
1432      else:
1433        hang += HANG_SUFFIX
1434
1435    if all_inputs and depth > max_depth:
1436      lines.append(hang + ELLIPSIS)
1437      unfinished.pop()
1438      return
1439
1440    hang += DEPTH_TEMPLATE % depth
1441
1442    for i, inp in enumerate(all_inputs):
1443      op_type = self._debug_dump.node_op_type(debug_graphs.get_node_name(inp))
1444      if op_type in self._GRAPH_STRUCT_OP_TYPE_DENYLIST:
1445        continue
1446
1447      if is_ctrl[i]:
1448        ctrl_str = CTRL_LABEL
1449      else:
1450        ctrl_str = ""
1451
1452      op_type_str = ""
1453      if show_op_type:
1454        op_type_str = OP_TYPE_TEMPLATE % op_type
1455
1456      if i == len(all_inputs) - 1:
1457        unfinished.pop()
1458
1459      line = hang + ctrl_str + op_type_str + inp
1460      lines.append(line)
1461      if command_template:
1462        attr_segs[len(lines) - 1] = [(
1463            len(line) - len(inp), len(line),
1464            debugger_cli_common.MenuItem(None, command_template % inp))]
1465
1466      # Recursive call.
1467      # The input's/output's name can be a tensor name, in the case of node
1468      # with >1 output slots.
1469      inp_node_name, _ = debug_graphs.parse_node_or_tensor_name(inp)
1470      self._dfs_from_node(
1471          lines,
1472          attr_segs,
1473          inp_node_name,
1474          tracker,
1475          max_depth,
1476          depth + 1,
1477          unfinished,
1478          include_control=include_control,
1479          show_op_type=show_op_type,
1480          command_template=command_template)
1481
1482  def _format_neighbors(self, neighbor_type, non_ctrls, ctrls):
1483    """List neighbors (inputs or recipients) of a node.
1484
1485    Args:
1486      neighbor_type: ("input" | "recipient")
1487      non_ctrls: Non-control neighbor node names, as a list of str.
1488      ctrls: Control neighbor node names, as a list of str.
1489
1490    Returns:
1491      A RichTextLines object.
1492    """
1493
1494    # TODO(cais): Return RichTextLines instead, to allow annotation of node
1495    # names.
1496    lines = []
1497    font_attr_segs = {}
1498
1499    lines.append("")
1500    lines.append("  %d %s(s) + %d control %s(s):" %
1501                 (len(non_ctrls), neighbor_type, len(ctrls), neighbor_type))
1502    lines.append("    %d %s(s):" % (len(non_ctrls), neighbor_type))
1503    for non_ctrl in non_ctrls:
1504      line = "      [%s] %s" % (self._debug_dump.node_op_type(non_ctrl),
1505                                non_ctrl)
1506      lines.append(line)
1507      font_attr_segs[len(lines) - 1] = [(
1508          len(line) - len(non_ctrl), len(line),
1509          debugger_cli_common.MenuItem(None, "ni -a -d -t %s" % non_ctrl))]
1510
1511    if ctrls:
1512      lines.append("")
1513      lines.append("    %d control %s(s):" % (len(ctrls), neighbor_type))
1514      for ctrl in ctrls:
1515        line = "      [%s] %s" % (self._debug_dump.node_op_type(ctrl), ctrl)
1516        lines.append(line)
1517        font_attr_segs[len(lines) - 1] = [(
1518            len(line) - len(ctrl), len(line),
1519            debugger_cli_common.MenuItem(None, "ni -a -d -t %s" % ctrl))]
1520
1521    return debugger_cli_common.RichTextLines(
1522        lines, font_attr_segs=font_attr_segs)
1523
1524  def _list_node_attributes(self, node_name):
1525    """List neighbors (inputs or recipients) of a node.
1526
1527    Args:
1528      node_name: Name of the node of which the attributes are to be listed.
1529
1530    Returns:
1531      A RichTextLines object.
1532    """
1533
1534    lines = []
1535    lines.append("")
1536    lines.append("Node attributes:")
1537
1538    attrs = self._debug_dump.node_attributes(node_name)
1539    for attr_key in attrs:
1540      lines.append("  %s:" % attr_key)
1541      attr_val_str = repr(attrs[attr_key]).strip().replace("\n", " ")
1542      lines.append("    %s" % attr_val_str)
1543      lines.append("")
1544
1545    return debugger_cli_common.RichTextLines(lines)
1546
1547  def _list_node_dumps(self, node_name):
1548    """List dumped tensor data from a node.
1549
1550    Args:
1551      node_name: Name of the node of which the attributes are to be listed.
1552
1553    Returns:
1554      A RichTextLines object.
1555    """
1556
1557    lines = []
1558    font_attr_segs = {}
1559
1560    watch_keys = self._debug_dump.debug_watch_keys(node_name)
1561
1562    dump_count = 0
1563    for watch_key in watch_keys:
1564      debug_tensor_data = self._debug_dump.watch_key_to_data(watch_key)
1565      for datum in debug_tensor_data:
1566        line = "  Slot %d @ %s @ %.3f ms" % (
1567            datum.output_slot, datum.debug_op,
1568            (datum.timestamp - self._debug_dump.t0) / 1000.0)
1569        lines.append(line)
1570        command = "pt %s:%d -n %d" % (node_name, datum.output_slot, dump_count)
1571        font_attr_segs[len(lines) - 1] = [(
1572            2, len(line), debugger_cli_common.MenuItem(None, command))]
1573        dump_count += 1
1574
1575    output = debugger_cli_common.RichTextLines(
1576        lines, font_attr_segs=font_attr_segs)
1577    output_with_header = debugger_cli_common.RichTextLines(
1578        ["%d dumped tensor(s):" % dump_count, ""])
1579    output_with_header.extend(output)
1580    return output_with_header
1581
1582
1583def create_analyzer_ui(debug_dump,
1584                       tensor_filters=None,
1585                       ui_type="curses",
1586                       on_ui_exit=None,
1587                       config=None):
1588  """Create an instance of CursesUI based on a DebugDumpDir object.
1589
1590  Args:
1591    debug_dump: (debug_data.DebugDumpDir) The debug dump to use.
1592    tensor_filters: (dict) A dict mapping tensor filter name (str) to tensor
1593      filter (Callable).
1594    ui_type: (str) requested UI type, e.g., "curses", "readline".
1595    on_ui_exit: (`Callable`) the callback to be called when the UI exits.
1596    config: A `cli_config.CLIConfig` object.
1597
1598  Returns:
1599    (base_ui.BaseUI) A BaseUI subtype object with a set of standard analyzer
1600      commands and tab-completions registered.
1601  """
1602  if config is None:
1603    config = cli_config.CLIConfig()
1604
1605  analyzer = DebugAnalyzer(debug_dump, config=config)
1606  if tensor_filters:
1607    for tensor_filter_name in tensor_filters:
1608      analyzer.add_tensor_filter(
1609          tensor_filter_name, tensor_filters[tensor_filter_name])
1610
1611  cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit, config=config)
1612  cli.register_command_handler(
1613      "list_tensors",
1614      analyzer.list_tensors,
1615      analyzer.get_help("list_tensors"),
1616      prefix_aliases=["lt"])
1617  cli.register_command_handler(
1618      "node_info",
1619      analyzer.node_info,
1620      analyzer.get_help("node_info"),
1621      prefix_aliases=["ni"])
1622  cli.register_command_handler(
1623      "list_inputs",
1624      analyzer.list_inputs,
1625      analyzer.get_help("list_inputs"),
1626      prefix_aliases=["li"])
1627  cli.register_command_handler(
1628      "list_outputs",
1629      analyzer.list_outputs,
1630      analyzer.get_help("list_outputs"),
1631      prefix_aliases=["lo"])
1632  cli.register_command_handler(
1633      "print_tensor",
1634      analyzer.print_tensor,
1635      analyzer.get_help("print_tensor"),
1636      prefix_aliases=["pt"])
1637  cli.register_command_handler(
1638      "print_source",
1639      analyzer.print_source,
1640      analyzer.get_help("print_source"),
1641      prefix_aliases=["ps"])
1642  cli.register_command_handler(
1643      "list_source",
1644      analyzer.list_source,
1645      analyzer.get_help("list_source"),
1646      prefix_aliases=["ls"])
1647  cli.register_command_handler(
1648      "eval",
1649      analyzer.evaluate_expression,
1650      analyzer.get_help("eval"),
1651      prefix_aliases=["ev"])
1652
1653  dumped_tensor_names = []
1654  for datum in debug_dump.dumped_tensor_data:
1655    dumped_tensor_names.append("%s:%d" % (datum.node_name, datum.output_slot))
1656
1657  # Tab completions for command "print_tensors".
1658  cli.register_tab_comp_context(["print_tensor", "pt"], dumped_tensor_names)
1659
1660  return cli
1661