• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Debugger Wrapper Session Consisting of a Local Curses-based CLI."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import argparse
21import os
22import sys
23import tempfile
24
25# Google-internal import(s).
26from tensorflow.python.debug.cli import analyzer_cli
27from tensorflow.python.debug.cli import cli_config
28from tensorflow.python.debug.cli import cli_shared
29from tensorflow.python.debug.cli import command_parser
30from tensorflow.python.debug.cli import debugger_cli_common
31from tensorflow.python.debug.cli import profile_analyzer_cli
32from tensorflow.python.debug.cli import ui_factory
33from tensorflow.python.debug.lib import common
34from tensorflow.python.debug.lib import debug_data
35from tensorflow.python.debug.wrappers import framework
36from tensorflow.python.lib.io import file_io
37
38
39_DUMP_ROOT_PREFIX = "tfdbg_"
40
41
42# TODO(donglin) Remove use_random_config_path after b/137652456 is fixed.
43class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
44  """Concrete subclass of BaseDebugWrapperSession implementing a local CLI.
45
46  This class has all the methods that a `session.Session` object has, in order
47  to support debugging with minimal code changes. Invoking its `run()` method
48  will launch the command-line interface (CLI) of tfdbg.
49  """
50
51  def __init__(self,
52               sess,
53               dump_root=None,
54               log_usage=True,
55               ui_type="curses",
56               thread_name_filter=None,
57               config_file_path=False):
58    """Constructor of LocalCLIDebugWrapperSession.
59
60    Args:
61      sess: The TensorFlow `Session` object being wrapped.
62      dump_root: (`str`) optional path to the dump root directory. Must be a
63        directory that does not exist or an empty directory. If the directory
64        does not exist, it will be created by the debugger core during debug
65        `run()` calls and removed afterwards. If `None`, the debug dumps will
66        be at tfdbg_<random_string> under the system temp directory.
67      log_usage: (`bool`) whether the usage of this class is to be logged.
68      ui_type: (`str`) requested UI type. Currently supported:
69        (curses | readline)
70      thread_name_filter: Regular-expression white list for thread name. See
71        the doc of `BaseDebugWrapperSession` for details.
72      config_file_path: Optional override to the default configuration file
73        path, which is at `${HOME}/.tfdbg_config`.
74
75    Raises:
76      ValueError: If dump_root is an existing and non-empty directory or if
77        dump_root is a file.
78    """
79
80    if log_usage:
81      pass  # No logging for open-source.
82
83    framework.BaseDebugWrapperSession.__init__(
84        self, sess, thread_name_filter=thread_name_filter)
85
86    if not dump_root:
87      self._dump_root = tempfile.mktemp(prefix=_DUMP_ROOT_PREFIX)
88    else:
89      dump_root = os.path.expanduser(dump_root)
90      if os.path.isfile(dump_root):
91        raise ValueError("dump_root path points to a file: %s" % dump_root)
92      elif os.path.isdir(dump_root) and os.listdir(dump_root):
93        raise ValueError("dump_root path points to a non-empty directory: %s" %
94                         dump_root)
95
96      self._dump_root = dump_root
97
98    self._initialize_argparsers()
99
100    # Registered tensor filters.
101    self._tensor_filters = {}
102    # Register frequently-used filter(s).
103    self.add_tensor_filter("has_inf_or_nan", debug_data.has_inf_or_nan)
104
105    # Below are the state variables of this wrapper object.
106    # _active_tensor_filter: what (if any) tensor filter is in effect. If such
107    #   a filter is in effect, this object will call run() method of the
108    #   underlying TensorFlow Session object until the filter passes. This is
109    #   activated by the "-f" flag of the "run" command.
110    # _run_through_times: keeps track of how many times the wrapper needs to
111    #   run through without stopping at the run-end CLI. It is activated by the
112    #   "-t" option of the "run" command.
113    # _skip_debug: keeps track of whether the current run should be executed
114    #   without debugging. It is activated by the "-n" option of the "run"
115    #   command.
116    #
117    # _run_start_response: keeps track what OnRunStartResponse the wrapper
118    #   should return at the next run-start callback. If this information is
119    #   unavailable (i.e., is None), the run-start CLI will be launched to ask
120    #   the user. This is the case, e.g., right before the first run starts.
121    self._active_tensor_filter = None
122    self._active_filter_exclude_node_names = None
123    self._active_tensor_filter_run_start_response = None
124    self._run_through_times = 1
125    self._skip_debug = False
126    self._run_start_response = None
127    self._is_run_start = True
128    self._ui_type = ui_type
129    self._config = None
130    if config_file_path:
131      self._config = cli_config.CLIConfig(config_file_path=config_file_path)
132
133  def _is_disk_usage_reset_each_run(self):
134    # The dumped tensors are all cleaned up after every Session.run
135    # in a command-line wrapper.
136    return True
137
138  def _initialize_argparsers(self):
139    self._argparsers = {}
140    ap = argparse.ArgumentParser(
141        description="Run through, with or without debug tensor watching.",
142        usage=argparse.SUPPRESS)
143    ap.add_argument(
144        "-t",
145        "--times",
146        dest="times",
147        type=int,
148        default=1,
149        help="How many Session.run() calls to proceed with.")
150    ap.add_argument(
151        "-n",
152        "--no_debug",
153        dest="no_debug",
154        action="store_true",
155        help="Run through without debug tensor watching.")
156    ap.add_argument(
157        "-f",
158        "--till_filter_pass",
159        dest="till_filter_pass",
160        type=str,
161        default="",
162        help="Run until a tensor in the graph passes the specified filter.")
163    ap.add_argument(
164        "-fenn",
165        "--filter_exclude_node_names",
166        dest="filter_exclude_node_names",
167        type=str,
168        default="",
169        help="When applying the tensor filter, exclude node with names "
170        "matching the regular expression. Applicable only if --tensor_filter "
171        "or -f is used.")
172    ap.add_argument(
173        "--node_name_filter",
174        dest="node_name_filter",
175        type=str,
176        default="",
177        help="Regular-expression filter for node names to be watched in the "
178        "run, e.g., loss, reshape.*")
179    ap.add_argument(
180        "--op_type_filter",
181        dest="op_type_filter",
182        type=str,
183        default="",
184        help="Regular-expression filter for op type to be watched in the run, "
185        "e.g., (MatMul|Add), Variable.*")
186    ap.add_argument(
187        "--tensor_dtype_filter",
188        dest="tensor_dtype_filter",
189        type=str,
190        default="",
191        help="Regular-expression filter for tensor dtype to be watched in the "
192        "run, e.g., (float32|float64), int.*")
193    ap.add_argument(
194        "-p",
195        "--profile",
196        dest="profile",
197        action="store_true",
198        help="Run and profile TensorFlow graph execution.")
199    self._argparsers["run"] = ap
200
201    ap = argparse.ArgumentParser(
202        description="Display information about this Session.run() call.",
203        usage=argparse.SUPPRESS)
204    self._argparsers["run_info"] = ap
205
206    self._argparsers["print_feed"] = command_parser.get_print_tensor_argparser(
207        "Print the value of a feed in feed_dict.")
208
209  def add_tensor_filter(self, filter_name, tensor_filter):
210    """Add a tensor filter.
211
212    Args:
213      filter_name: (`str`) name of the filter.
214      tensor_filter: (`callable`) the filter callable. See the doc string of
215        `DebugDumpDir.find()` for more details about its signature.
216    """
217
218    self._tensor_filters[filter_name] = tensor_filter
219
220  def on_session_init(self, request):
221    """Overrides on-session-init callback.
222
223    Args:
224      request: An instance of `OnSessionInitRequest`.
225
226    Returns:
227      An instance of `OnSessionInitResponse`.
228    """
229
230    return framework.OnSessionInitResponse(
231        framework.OnSessionInitAction.PROCEED)
232
  def on_run_start(self, request):
    """Overrides on-run-start callback.

    Invoked by the wrapper framework immediately before each Session.run()
    call. Depending on the wrapper state this either reuses a previously
    cached response (run-till-filter-pass mode), runs through without
    debugging ("-t" countdown), or launches the run-start CLI to ask the
    user what to do.

    Args:
      request: An instance of `OnRunStartRequest`.

    Returns:
      An instance of `OnRunStartResponse`.
    """
    self._is_run_start = True
    # Note: this also decrements self._run_through_times by one for the
    # current call (see _update_run_calls_state).
    self._update_run_calls_state(
        request.run_call_count, request.fetches, request.feed_dict,
        is_callable_runner=request.is_callable_runner)

    if self._active_tensor_filter:
      # If we are running until a filter passes, we just need to keep running
      # with the previous `OnRunStartResponse`.
      return self._active_tensor_filter_run_start_response

    self._exit_if_requested_by_user()

    if self._run_call_count > 1 and not self._skip_debug:
      if self._run_through_times > 0:
        # Just run through without debugging.
        return framework.OnRunStartResponse(
            framework.OnRunStartAction.NON_DEBUG_RUN, [])
      elif self._run_through_times == 0:
        # It is the run at which the run-end CLI will be launched: activate
        # debugging.
        return (self._run_start_response or
                framework.OnRunStartResponse(
                    framework.OnRunStartAction.DEBUG_RUN,
                    self._get_run_debug_urls()))

    if self._run_start_response is None:
      # No decision is pending from a prior CLI interaction: bring up the
      # run-start CLI and ask the user.
      self._prep_cli_for_run_start()

      self._run_start_response = self._launch_cli()
      if self._active_tensor_filter:
        # The user issued "run -f <filter>": cache the response so subsequent
        # run-start callbacks can reuse it until the filter passes.
        self._active_tensor_filter_run_start_response = self._run_start_response
      if self._run_through_times > 1:
        self._run_through_times -= 1

    self._exit_if_requested_by_user()
    return self._run_start_response
278
279  def _exit_if_requested_by_user(self):
280    if self._run_start_response == debugger_cli_common.EXPLICIT_USER_EXIT:
281      # Explicit user "exit" command leads to sys.exit(1).
282      print(
283          "Note: user exited from debugger CLI: Calling sys.exit(1).",
284          file=sys.stderr)
285      sys.exit(1)
286
287  def _prep_cli_for_run_start(self):
288    """Prepare (but not launch) the CLI for run-start."""
289    self._run_cli = ui_factory.get_ui(self._ui_type, config=self._config)
290
291    help_intro = debugger_cli_common.RichTextLines([])
292    if self._run_call_count == 1:
293      # Show logo at the onset of the first run.
294      help_intro.extend(cli_shared.get_tfdbg_logo())
295      help_intro.extend(debugger_cli_common.get_tensorflow_version_lines())
296    help_intro.extend(debugger_cli_common.RichTextLines("Upcoming run:"))
297    help_intro.extend(self._run_info)
298
299    self._run_cli.set_help_intro(help_intro)
300
301    # Create initial screen output detailing the run.
302    self._title = "run-start: " + self._run_description
303    self._init_command = "run_info"
304    self._title_color = "blue_on_white"
305
  def on_run_end(self, request):
    """Overrides on-run-end callback.

    Actions taken:
      1) Load the debug dump.
      2) Bring up the Analyzer CLI.

    Args:
      request: An instance of `OnRunEndRequest`.

    Returns:
      An instance of `OnRunEndResponse`.
    """

    self._is_run_start = False
    if request.performed_action == framework.OnRunStartAction.DEBUG_RUN:
      # Prefer partition graphs from the run metadata; fall back to the
      # client graph def when the metadata carries none.
      partition_graphs = None
      if request.run_metadata and request.run_metadata.partition_graphs:
        partition_graphs = request.run_metadata.partition_graphs
      elif request.client_graph_def:
        partition_graphs = [request.client_graph_def]

      if request.tf_error and not os.path.isdir(self._dump_root):
        # It is possible that the dump root may not exist due to errors that
        # have occurred prior to graph execution (e.g., invalid device
        # assignments), in which case we will just raise the exception as the
        # unwrapped Session does.
        raise request.tf_error

      debug_dump = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=partition_graphs)
      debug_dump.set_python_graph(self._sess.graph)

      passed_filter = None
      passed_filter_exclude_node_names = None
      if self._active_tensor_filter:
        if not debug_dump.find(
            self._tensor_filters[self._active_tensor_filter], first_n=1,
            exclude_node_names=self._active_filter_exclude_node_names):
          # No dumped tensor passes the filter in this run. Clean up the dump
          # directory and move on.
          self._remove_dump_root()
          return framework.OnRunEndResponse()
        else:
          # Some dumped tensor(s) from this run passed the filter.
          passed_filter = self._active_tensor_filter
          passed_filter_exclude_node_names = (
              self._active_filter_exclude_node_names)
          # Deactivate run-till-filter-pass mode now that the filter passed.
          self._active_tensor_filter = None
          self._active_filter_exclude_node_names = None

      self._prep_debug_cli_for_run_end(
          debug_dump, request.tf_error, passed_filter,
          passed_filter_exclude_node_names)

      # The CLI's exit token becomes the response for the NEXT run-start.
      self._run_start_response = self._launch_cli()

      # Clean up the dump generated by this run.
      self._remove_dump_root()
    elif request.performed_action == framework.OnRunStartAction.PROFILE_RUN:
      self._prep_profile_cli_for_run_end(self._sess.graph, request.run_metadata)
      self._run_start_response = self._launch_cli()
    else:
      # No debug information to show following a non-debug run() call.
      self._run_start_response = None

    # Return placeholder response that currently holds no additional
    # information.
    return framework.OnRunEndResponse()
375
376  def _remove_dump_root(self):
377    if os.path.isdir(self._dump_root):
378      file_io.delete_recursively(self._dump_root)
379
380  def _prep_debug_cli_for_run_end(self,
381                                  debug_dump,
382                                  tf_error,
383                                  passed_filter,
384                                  passed_filter_exclude_node_names):
385    """Prepare (but not launch) CLI for run-end, with debug dump from the run.
386
387    Args:
388      debug_dump: (debug_data.DebugDumpDir) The debug dump directory from this
389        run.
390      tf_error: (None or OpError) OpError that happened during the run() call
391        (if any).
392      passed_filter: (None or str) Name of the tensor filter that just passed
393        and caused the preparation of this run-end CLI (if any).
394      passed_filter_exclude_node_names: (None or str) Regular expression used
395        with the tensor filter to exclude ops with names matching the regular
396        expression.
397    """
398
399    if tf_error:
400      help_intro = cli_shared.get_error_intro(tf_error)
401
402      self._init_command = "help"
403      self._title_color = "red_on_white"
404    else:
405      help_intro = None
406      self._init_command = "lt"
407
408      self._title_color = "black_on_white"
409      if passed_filter is not None:
410        # Some dumped tensor(s) from this run passed the filter.
411        self._init_command = "lt -f %s" % passed_filter
412        if passed_filter_exclude_node_names:
413          self._init_command += (" --filter_exclude_node_names %s" %
414                                 passed_filter_exclude_node_names)
415        self._title_color = "red_on_white"
416
417    self._run_cli = analyzer_cli.create_analyzer_ui(
418        debug_dump,
419        self._tensor_filters,
420        ui_type=self._ui_type,
421        on_ui_exit=self._remove_dump_root,
422        config=self._config)
423
424    # Get names of all dumped tensors.
425    dumped_tensor_names = []
426    for datum in debug_dump.dumped_tensor_data:
427      dumped_tensor_names.append("%s:%d" %
428                                 (datum.node_name, datum.output_slot))
429
430    # Tab completions for command "print_tensors".
431    self._run_cli.register_tab_comp_context(["print_tensor", "pt"],
432                                            dumped_tensor_names)
433
434    # Tab completion for commands "node_info", "list_inputs" and
435    # "list_outputs". The list comprehension is used below because nodes()
436    # output can be unicodes and they need to be converted to strs.
437    self._run_cli.register_tab_comp_context(
438        ["node_info", "ni", "list_inputs", "li", "list_outputs", "lo"],
439        [str(node_name) for node_name in debug_dump.nodes()])
440    # TODO(cais): Reduce API surface area for aliases vis-a-vis tab
441    #    completion contexts and registered command handlers.
442
443    self._title = "run-end: " + self._run_description
444
445    if help_intro:
446      self._run_cli.set_help_intro(help_intro)
447
448  def _prep_profile_cli_for_run_end(self, py_graph, run_metadata):
449    self._init_command = "lp"
450    self._run_cli = profile_analyzer_cli.create_profiler_ui(
451        py_graph, run_metadata, ui_type=self._ui_type,
452        config=self._run_cli.config)
453    self._title = "run-end (profiler mode): " + self._run_description
454
455  def _launch_cli(self):
456    """Launch the interactive command-line interface.
457
458    Returns:
459      The OnRunStartResponse specified by the user using the "run" command.
460    """
461
462    self._register_this_run_info(self._run_cli)
463    response = self._run_cli.run_ui(
464        init_command=self._init_command,
465        title=self._title,
466        title_color=self._title_color)
467
468    return response
469
470  def _run_info_handler(self, args, screen_info=None):
471    output = debugger_cli_common.RichTextLines([])
472
473    if self._run_call_count == 1:
474      output.extend(cli_shared.get_tfdbg_logo())
475      output.extend(debugger_cli_common.get_tensorflow_version_lines())
476    output.extend(self._run_info)
477
478    if (not self._is_run_start and
479        debugger_cli_common.MAIN_MENU_KEY in output.annotations):
480      menu = output.annotations[debugger_cli_common.MAIN_MENU_KEY]
481      if "list_tensors" not in menu.captions():
482        menu.insert(
483            0, debugger_cli_common.MenuItem("list_tensors", "list_tensors"))
484
485    return output
486
487  def _print_feed_handler(self, args, screen_info=None):
488    np_printoptions = cli_shared.numpy_printoptions_from_screen_info(
489        screen_info)
490
491    if not self._feed_dict:
492      return cli_shared.error(
493          "The feed_dict of the current run is None or empty.")
494
495    parsed = self._argparsers["print_feed"].parse_args(args)
496    tensor_name, tensor_slicing = (
497        command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))
498
499    feed_key = None
500    feed_value = None
501    for key in self._feed_dict:
502      key_name = common.get_graph_element_name(key)
503      if key_name == tensor_name:
504        feed_key = key_name
505        feed_value = self._feed_dict[key]
506        break
507
508    if feed_key is None:
509      return cli_shared.error(
510          "The feed_dict of the current run does not contain the key %s" %
511          tensor_name)
512    else:
513      return cli_shared.format_tensor(
514          feed_value,
515          feed_key + " (feed)",
516          np_printoptions,
517          print_all=parsed.print_all,
518          tensor_slicing=tensor_slicing,
519          highlight_options=cli_shared.parse_ranges_highlight(parsed.ranges),
520          include_numeric_summary=parsed.numeric_summary)
521
522  def _run_handler(self, args, screen_info=None):
523    """Command handler for "run" command during on-run-start."""
524
525    del screen_info  # Currently unused.
526
527    parsed = self._argparsers["run"].parse_args(args)
528    parsed.node_name_filter = parsed.node_name_filter or None
529    parsed.op_type_filter = parsed.op_type_filter or None
530    parsed.tensor_dtype_filter = parsed.tensor_dtype_filter or None
531
532    if parsed.filter_exclude_node_names and not parsed.till_filter_pass:
533      raise ValueError(
534          "The --filter_exclude_node_names (or -feon) flag is valid only if "
535          "the --till_filter_pass (or -f) flag is used.")
536
537    if parsed.profile:
538      raise debugger_cli_common.CommandLineExit(
539          exit_token=framework.OnRunStartResponse(
540              framework.OnRunStartAction.PROFILE_RUN, []))
541
542    self._skip_debug = parsed.no_debug
543    self._run_through_times = parsed.times
544
545    if parsed.times > 1 or parsed.no_debug:
546      # If requested -t times > 1, the very next run will be a non-debug run.
547      action = framework.OnRunStartAction.NON_DEBUG_RUN
548      debug_urls = []
549    else:
550      action = framework.OnRunStartAction.DEBUG_RUN
551      debug_urls = self._get_run_debug_urls()
552    run_start_response = framework.OnRunStartResponse(
553        action,
554        debug_urls,
555        node_name_regex_allowlist=parsed.node_name_filter,
556        op_type_regex_allowlist=parsed.op_type_filter,
557        tensor_dtype_regex_allowlist=parsed.tensor_dtype_filter)
558
559    if parsed.till_filter_pass:
560      # For the run-till-filter-pass (run -f) mode, use the DEBUG_RUN
561      # option to access the intermediate tensors, and set the corresponding
562      # state flag of the class itself to True.
563      if parsed.till_filter_pass in self._tensor_filters:
564        action = framework.OnRunStartAction.DEBUG_RUN
565        self._active_tensor_filter = parsed.till_filter_pass
566        self._active_filter_exclude_node_names = (
567            parsed.filter_exclude_node_names)
568        self._active_tensor_filter_run_start_response = run_start_response
569      else:
570        # Handle invalid filter name.
571        return debugger_cli_common.RichTextLines(
572            ["ERROR: tensor filter \"%s\" does not exist." %
573             parsed.till_filter_pass])
574
575    # Raise CommandLineExit exception to cause the CLI to exit.
576    raise debugger_cli_common.CommandLineExit(exit_token=run_start_response)
577
578  def _register_this_run_info(self, curses_cli):
579    curses_cli.register_command_handler(
580        "run",
581        self._run_handler,
582        self._argparsers["run"].format_help(),
583        prefix_aliases=["r"])
584    curses_cli.register_command_handler(
585        "run_info",
586        self._run_info_handler,
587        self._argparsers["run_info"].format_help(),
588        prefix_aliases=["ri"])
589    curses_cli.register_command_handler(
590        "print_feed",
591        self._print_feed_handler,
592        self._argparsers["print_feed"].format_help(),
593        prefix_aliases=["pf"])
594
595    if self._tensor_filters:
596      # Register tab completion for the filter names.
597      curses_cli.register_tab_comp_context(["run", "r"],
598                                           list(self._tensor_filters.keys()))
599    if self._feed_dict and hasattr(self._feed_dict, "keys"):
600      # Register tab completion for feed_dict keys.
601      feed_keys = [common.get_graph_element_name(key)
602                   for key in self._feed_dict.keys()]
603      curses_cli.register_tab_comp_context(["print_feed", "pf"], feed_keys)
604
605  def _get_run_debug_urls(self):
606    """Get the debug_urls value for the current run() call.
607
608    Returns:
609      debug_urls: (list of str) Debug URLs for the current run() call.
610        Currently, the list consists of only one URL that is a file:// URL.
611    """
612
613    return ["file://" + self._dump_root]
614
  def _update_run_calls_state(self,
                              run_call_count,
                              fetches,
                              feed_dict,
                              is_callable_runner=False):
    """Update the internal state with regard to run() call history.

    Args:
      run_call_count: (int) Number of run() calls that have occurred.
      fetches: a node/tensor or a list of node/tensor that are the fetches of
        the run() call. This is the same as the fetches argument to the run()
        call.
      feed_dict: None or a dict. This is the feed_dict argument to the run()
        call.
      is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.
    """

    self._run_call_count = run_call_count
    self._feed_dict = feed_dict
    self._run_description = cli_shared.get_run_short_description(
        run_call_count,
        fetches,
        feed_dict,
        is_callable_runner=is_callable_runner)
    # Each run() call consumes one "run-through" (see the "-t" option of the
    # "run" command handled in _run_handler).
    self._run_through_times -= 1

    self._run_info = cli_shared.get_run_start_intro(
        run_call_count,
        fetches,
        feed_dict,
        self._tensor_filters,
        is_callable_runner=is_callable_runner)
648