# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Debugger Wrapper Session Consisting of a Local Curses-based CLI."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys
import tempfile

# Google-internal import(s).
from tensorflow.python.debug.cli import analyzer_cli
from tensorflow.python.debug.cli import cli_config
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import profile_analyzer_cli
from tensorflow.python.debug.cli import ui_factory
from tensorflow.python.debug.lib import common
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import framework
from tensorflow.python.lib.io import file_io


_DUMP_ROOT_PREFIX = "tfdbg_"


# TODO(donglin) Remove use_random_config_path after b/137652456 is fixed.
class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
  """Concrete subclass of BaseDebugWrapperSession implementing a local CLI.

  This class has all the methods that a `session.Session` object has, in order
  to support debugging with minimal code changes. Invoking its `run()` method
  will launch the command-line interface (CLI) of tfdbg.
  """

  def __init__(self,
               sess,
               dump_root=None,
               log_usage=True,
               ui_type="curses",
               thread_name_filter=None,
               config_file_path=False):
    """Constructor of LocalCLIDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      dump_root: (`str`) optional path to the dump root directory. Must be a
        directory that does not exist or an empty directory. If the directory
        does not exist, it will be created by the debugger core during debug
        `run()` calls and removed afterwards. If `None`, the debug dumps will
        be at tfdbg_<random_string> under the system temp directory.
      log_usage: (`bool`) whether the usage of this class is to be logged.
      ui_type: (`str`) requested UI type. Currently supported:
        (curses | readline)
      thread_name_filter: Regular-expression allowlist for thread names. See
        the doc of `BaseDebugWrapperSession` for details.
      config_file_path: Optional override of the default configuration file
        path, which is at `${HOME}/.tfdbg_config`.

    Raises:
      ValueError: If dump_root is an existing and non-empty directory or if
        dump_root is a file.
    """

    if log_usage:
      pass  # No logging for open-source.

    framework.BaseDebugWrapperSession.__init__(
        self, sess, thread_name_filter=thread_name_filter)

    if not dump_root:
      self._dump_root = tempfile.mktemp(prefix=_DUMP_ROOT_PREFIX)
    else:
      dump_root = os.path.expanduser(dump_root)
      if os.path.isfile(dump_root):
        raise ValueError("dump_root path points to a file: %s" % dump_root)
      elif os.path.isdir(dump_root) and os.listdir(dump_root):
        raise ValueError("dump_root path points to a non-empty directory: %s" %
                         dump_root)

      self._dump_root = dump_root

    self._initialize_argparsers()

    # Registered tensor filters.
    self._tensor_filters = {}
    # Register frequently-used filter(s).
    self.add_tensor_filter("has_inf_or_nan", debug_data.has_inf_or_nan)

    # Below are the state variables of this wrapper object.
    # _active_tensor_filter: what (if any) tensor filter is in effect. If such
    #   a filter is in effect, this object will call the run() method of the
    #   underlying TensorFlow Session object until the filter passes. This is
    #   activated by the "-f" flag of the "run" command.
    # _run_through_times: keeps track of how many times the wrapper needs to
    #   run through without stopping at the run-end CLI. It is activated by the
    #   "-t" option of the "run" command.
    # _skip_debug: keeps track of whether the current run should be executed
    #   without debugging. It is activated by the "-n" option of the "run"
    #   command.
    #
    # _run_start_response: keeps track of what OnRunStartResponse the wrapper
    #   should return at the next run-start callback. If this information is
    #   unavailable (i.e., is None), the run-start CLI will be launched to ask
    #   the user. This is the case, e.g., right before the first run starts.
    self._active_tensor_filter = None
    self._active_filter_exclude_node_names = None
    self._active_tensor_filter_run_start_response = None
    self._run_through_times = 1
    self._skip_debug = False
    self._run_start_response = None
    self._is_run_start = True
    self._ui_type = ui_type
    self._config = None
    if config_file_path:
      self._config = cli_config.CLIConfig(config_file_path=config_file_path)

  def _is_disk_usage_reset_each_run(self):
    # The dumped tensors are all cleaned up after every Session.run
    # in a command-line wrapper.
    return True

  def _initialize_argparsers(self):
    self._argparsers = {}
    ap = argparse.ArgumentParser(
        description="Run through, with or without debug tensor watching.",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "-t",
        "--times",
        dest="times",
        type=int,
        default=1,
        help="How many Session.run() calls to proceed with.")
    ap.add_argument(
        "-n",
        "--no_debug",
        dest="no_debug",
        action="store_true",
        help="Run through without debug tensor watching.")
    ap.add_argument(
        "-f",
        "--till_filter_pass",
        dest="till_filter_pass",
        type=str,
        default="",
        help="Run until a tensor in the graph passes the specified filter.")
    ap.add_argument(
        "-fenn",
        "--filter_exclude_node_names",
        dest="filter_exclude_node_names",
        type=str,
        default="",
        help="When applying the tensor filter, exclude nodes with names "
        "matching the regular expression. Applicable only if --tensor_filter "
        "or -f is used.")
    ap.add_argument(
        "--node_name_filter",
        dest="node_name_filter",
        type=str,
        default="",
        help="Regular-expression filter for node names to be watched in the "
        "run, e.g., loss, reshape.*")
    ap.add_argument(
        "--op_type_filter",
        dest="op_type_filter",
        type=str,
        default="",
        help="Regular-expression filter for op type to be watched in the run, "
        "e.g., (MatMul|Add), Variable.*")
    ap.add_argument(
        "--tensor_dtype_filter",
        dest="tensor_dtype_filter",
        type=str,
        default="",
        help="Regular-expression filter for tensor dtype to be watched in the "
        "run, e.g., (float32|float64), int.*")
    ap.add_argument(
        "-p",
        "--profile",
        dest="profile",
        action="store_true",
        help="Run and profile TensorFlow graph execution.")
    self._argparsers["run"] = ap

    ap = argparse.ArgumentParser(
        description="Display information about this Session.run() call.",
        usage=argparse.SUPPRESS)
    self._argparsers["run_info"] = ap

    self._argparsers["print_feed"] = command_parser.get_print_tensor_argparser(
        "Print the value of a feed in feed_dict.")

  def add_tensor_filter(self, filter_name, tensor_filter):
    """Add a tensor filter.

    Args:
      filter_name: (`str`) name of the filter.
      tensor_filter: (`callable`) the filter callable. See the doc string of
        `DebugDumpDir.find()` for more details about its signature.
    """

    self._tensor_filters[filter_name] = tensor_filter

  def on_session_init(self, request):
    """Overrides on-session-init callback.

    Args:
      request: An instance of `OnSessionInitRequest`.

    Returns:
      An instance of `OnSessionInitResponse`.
    """

    return framework.OnSessionInitResponse(
        framework.OnSessionInitAction.PROCEED)

  def on_run_start(self, request):
    """Overrides on-run-start callback.

    Args:
      request: An instance of `OnRunStartRequest`.

    Returns:
      An instance of `OnRunStartResponse`.
    """
    self._is_run_start = True
    self._update_run_calls_state(
        request.run_call_count, request.fetches, request.feed_dict,
        is_callable_runner=request.is_callable_runner)

    if self._active_tensor_filter:
      # If we are running until a filter passes, we just need to keep running
      # with the previous `OnRunStartResponse`.
      return self._active_tensor_filter_run_start_response

    self._exit_if_requested_by_user()

    if self._run_call_count > 1 and not self._skip_debug:
      if self._run_through_times > 0:
        # Just run through without debugging.
        return framework.OnRunStartResponse(
            framework.OnRunStartAction.NON_DEBUG_RUN, [])
      elif self._run_through_times == 0:
        # It is the run at which the run-end CLI will be launched: activate
        # debugging.
        return (self._run_start_response or
                framework.OnRunStartResponse(
                    framework.OnRunStartAction.DEBUG_RUN,
                    self._get_run_debug_urls()))

    if self._run_start_response is None:
      self._prep_cli_for_run_start()

      self._run_start_response = self._launch_cli()
      if self._active_tensor_filter:
        self._active_tensor_filter_run_start_response = self._run_start_response
      if self._run_through_times > 1:
        self._run_through_times -= 1

    self._exit_if_requested_by_user()
    return self._run_start_response

  def _exit_if_requested_by_user(self):
    if self._run_start_response == debugger_cli_common.EXPLICIT_USER_EXIT:
      # Explicit user "exit" command leads to sys.exit(1).
      print(
          "Note: user exited from debugger CLI: Calling sys.exit(1).",
          file=sys.stderr)
      sys.exit(1)

  def _prep_cli_for_run_start(self):
    """Prepare (but not launch) the CLI for run-start."""
    self._run_cli = ui_factory.get_ui(self._ui_type, config=self._config)

    help_intro = debugger_cli_common.RichTextLines([])
    if self._run_call_count == 1:
      # Show logo at the onset of the first run.
      help_intro.extend(cli_shared.get_tfdbg_logo())
      help_intro.extend(debugger_cli_common.get_tensorflow_version_lines())
    help_intro.extend(debugger_cli_common.RichTextLines("Upcoming run:"))
    help_intro.extend(self._run_info)

    self._run_cli.set_help_intro(help_intro)

    # Create initial screen output detailing the run.
    self._title = "run-start: " + self._run_description
    self._init_command = "run_info"
    self._title_color = "blue_on_white"

  def on_run_end(self, request):
    """Overrides on-run-end callback.

    Actions taken:
      1) Load the debug dump.
      2) Bring up the Analyzer CLI.

    Args:
      request: An instance of `OnRunEndRequest`.

    Returns:
      An instance of `OnRunEndResponse`.
    """

    self._is_run_start = False
    if request.performed_action == framework.OnRunStartAction.DEBUG_RUN:
      partition_graphs = None
      if request.run_metadata and request.run_metadata.partition_graphs:
        partition_graphs = request.run_metadata.partition_graphs
      elif request.client_graph_def:
        partition_graphs = [request.client_graph_def]

      if request.tf_error and not os.path.isdir(self._dump_root):
        # It is possible that the dump root may not exist due to errors that
        # have occurred prior to graph execution (e.g., invalid device
        # assignments), in which case we will just raise the exception as the
        # unwrapped Session does.
        raise request.tf_error

      debug_dump = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=partition_graphs)
      debug_dump.set_python_graph(self._sess.graph)

      passed_filter = None
      passed_filter_exclude_node_names = None
      if self._active_tensor_filter:
        if not debug_dump.find(
            self._tensor_filters[self._active_tensor_filter], first_n=1,
            exclude_node_names=self._active_filter_exclude_node_names):
          # No dumped tensor passes the filter in this run. Clean up the dump
          # directory and move on.
          self._remove_dump_root()
          return framework.OnRunEndResponse()
        else:
          # Some dumped tensor(s) from this run passed the filter.
          passed_filter = self._active_tensor_filter
          passed_filter_exclude_node_names = (
              self._active_filter_exclude_node_names)
          self._active_tensor_filter = None
          self._active_filter_exclude_node_names = None

      self._prep_debug_cli_for_run_end(
          debug_dump, request.tf_error, passed_filter,
          passed_filter_exclude_node_names)

      self._run_start_response = self._launch_cli()

      # Clean up the dump generated by this run.
      self._remove_dump_root()
    elif request.performed_action == framework.OnRunStartAction.PROFILE_RUN:
      self._prep_profile_cli_for_run_end(self._sess.graph, request.run_metadata)
      self._run_start_response = self._launch_cli()
    else:
      # No debug information to show following a non-debug run() call.
      self._run_start_response = None

    # Return placeholder response that currently holds no additional
    # information.
    return framework.OnRunEndResponse()

  def _remove_dump_root(self):
    if os.path.isdir(self._dump_root):
      file_io.delete_recursively(self._dump_root)

  def _prep_debug_cli_for_run_end(self,
                                  debug_dump,
                                  tf_error,
                                  passed_filter,
                                  passed_filter_exclude_node_names):
    """Prepare (but not launch) CLI for run-end, with debug dump from the run.

    Args:
      debug_dump: (debug_data.DebugDumpDir) The debug dump directory from this
        run.
      tf_error: (None or OpError) OpError that happened during the run() call
        (if any).
      passed_filter: (None or str) Name of the tensor filter that just passed
        and caused the preparation of this run-end CLI (if any).
      passed_filter_exclude_node_names: (None or str) Regular expression used
        with the tensor filter to exclude ops with names matching the regular
        expression.
    """

    if tf_error:
      help_intro = cli_shared.get_error_intro(tf_error)

      self._init_command = "help"
      self._title_color = "red_on_white"
    else:
      help_intro = None
      self._init_command = "lt"

      self._title_color = "black_on_white"
      if passed_filter is not None:
        # Some dumped tensor(s) from this run passed the filter.
        self._init_command = "lt -f %s" % passed_filter
        if passed_filter_exclude_node_names:
          self._init_command += (" --filter_exclude_node_names %s" %
                                 passed_filter_exclude_node_names)
        self._title_color = "red_on_white"

    self._run_cli = analyzer_cli.create_analyzer_ui(
        debug_dump,
        self._tensor_filters,
        ui_type=self._ui_type,
        on_ui_exit=self._remove_dump_root,
        config=self._config)

    # Get names of all dumped tensors.
    dumped_tensor_names = []
    for datum in debug_dump.dumped_tensor_data:
      dumped_tensor_names.append("%s:%d" %
                                 (datum.node_name, datum.output_slot))

    # Tab completions for command "print_tensors".
    self._run_cli.register_tab_comp_context(["print_tensor", "pt"],
                                            dumped_tensor_names)

    # Tab completion for commands "node_info", "list_inputs" and
    # "list_outputs". The list comprehension is used below because nodes()
    # output can be unicode strings and they need to be converted to strs.
    self._run_cli.register_tab_comp_context(
        ["node_info", "ni", "list_inputs", "li", "list_outputs", "lo"],
        [str(node_name) for node_name in debug_dump.nodes()])
    # TODO(cais): Reduce API surface area for aliases vis-a-vis tab
    # completion contexts and registered command handlers.

    self._title = "run-end: " + self._run_description

    if help_intro:
      self._run_cli.set_help_intro(help_intro)

  def _prep_profile_cli_for_run_end(self, py_graph, run_metadata):
    self._init_command = "lp"
    self._run_cli = profile_analyzer_cli.create_profiler_ui(
        py_graph, run_metadata, ui_type=self._ui_type,
        config=self._run_cli.config)
    self._title = "run-end (profiler mode): " + self._run_description

  def _launch_cli(self):
    """Launch the interactive command-line interface.

    Returns:
      The OnRunStartResponse specified by the user using the "run" command.
    """

    self._register_this_run_info(self._run_cli)
    response = self._run_cli.run_ui(
        init_command=self._init_command,
        title=self._title,
        title_color=self._title_color)

    return response

  def _run_info_handler(self, args, screen_info=None):
    output = debugger_cli_common.RichTextLines([])

    if self._run_call_count == 1:
      output.extend(cli_shared.get_tfdbg_logo())
      output.extend(debugger_cli_common.get_tensorflow_version_lines())
    output.extend(self._run_info)

    if (not self._is_run_start and
        debugger_cli_common.MAIN_MENU_KEY in output.annotations):
      menu = output.annotations[debugger_cli_common.MAIN_MENU_KEY]
      if "list_tensors" not in menu.captions():
        menu.insert(
            0, debugger_cli_common.MenuItem("list_tensors", "list_tensors"))

    return output

  def _print_feed_handler(self, args, screen_info=None):
    np_printoptions = cli_shared.numpy_printoptions_from_screen_info(
        screen_info)

    if not self._feed_dict:
      return cli_shared.error(
          "The feed_dict of the current run is None or empty.")

    parsed = self._argparsers["print_feed"].parse_args(args)
    tensor_name, tensor_slicing = (
        command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))

    feed_key = None
    feed_value = None
    for key in self._feed_dict:
      key_name = common.get_graph_element_name(key)
      if key_name == tensor_name:
        feed_key = key_name
        feed_value = self._feed_dict[key]
        break

    if feed_key is None:
      return cli_shared.error(
          "The feed_dict of the current run does not contain the key %s" %
          tensor_name)
    else:
      return cli_shared.format_tensor(
          feed_value,
          feed_key + " (feed)",
          np_printoptions,
          print_all=parsed.print_all,
          tensor_slicing=tensor_slicing,
          highlight_options=cli_shared.parse_ranges_highlight(parsed.ranges),
          include_numeric_summary=parsed.numeric_summary)

  def _run_handler(self, args, screen_info=None):
    """Command handler for "run" command during on-run-start."""

    del screen_info  # Currently unused.

    parsed = self._argparsers["run"].parse_args(args)
    parsed.node_name_filter = parsed.node_name_filter or None
    parsed.op_type_filter = parsed.op_type_filter or None
    parsed.tensor_dtype_filter = parsed.tensor_dtype_filter or None

    if parsed.filter_exclude_node_names and not parsed.till_filter_pass:
      raise ValueError(
          "The --filter_exclude_node_names (or -fenn) flag is valid only if "
          "the --till_filter_pass (or -f) flag is used.")

    if parsed.profile:
      raise debugger_cli_common.CommandLineExit(
          exit_token=framework.OnRunStartResponse(
              framework.OnRunStartAction.PROFILE_RUN, []))

    self._skip_debug = parsed.no_debug
    self._run_through_times = parsed.times

    if parsed.times > 1 or parsed.no_debug:
      # If requested -t times > 1, the very next run will be a non-debug run.
      action = framework.OnRunStartAction.NON_DEBUG_RUN
      debug_urls = []
    else:
      action = framework.OnRunStartAction.DEBUG_RUN
      debug_urls = self._get_run_debug_urls()
    run_start_response = framework.OnRunStartResponse(
        action,
        debug_urls,
        node_name_regex_allowlist=parsed.node_name_filter,
        op_type_regex_allowlist=parsed.op_type_filter,
        tensor_dtype_regex_allowlist=parsed.tensor_dtype_filter)

    if parsed.till_filter_pass:
      # For the run-till-filter-pass (run -f) mode, use the DEBUG_RUN
      # option to access the intermediate tensors, and set the corresponding
      # state flag of the class itself to True.
      if parsed.till_filter_pass in self._tensor_filters:
        action = framework.OnRunStartAction.DEBUG_RUN
        self._active_tensor_filter = parsed.till_filter_pass
        self._active_filter_exclude_node_names = (
            parsed.filter_exclude_node_names)
        self._active_tensor_filter_run_start_response = run_start_response
      else:
        # Handle invalid filter name.
        return debugger_cli_common.RichTextLines(
            ["ERROR: tensor filter \"%s\" does not exist." %
             parsed.till_filter_pass])

    # Raise CommandLineExit exception to cause the CLI to exit.
    raise debugger_cli_common.CommandLineExit(exit_token=run_start_response)

  def _register_this_run_info(self, curses_cli):
    curses_cli.register_command_handler(
        "run",
        self._run_handler,
        self._argparsers["run"].format_help(),
        prefix_aliases=["r"])
    curses_cli.register_command_handler(
        "run_info",
        self._run_info_handler,
        self._argparsers["run_info"].format_help(),
        prefix_aliases=["ri"])
    curses_cli.register_command_handler(
        "print_feed",
        self._print_feed_handler,
        self._argparsers["print_feed"].format_help(),
        prefix_aliases=["pf"])

    if self._tensor_filters:
      # Register tab completion for the filter names.
      curses_cli.register_tab_comp_context(["run", "r"],
                                           list(self._tensor_filters.keys()))
    if self._feed_dict and hasattr(self._feed_dict, "keys"):
      # Register tab completion for feed_dict keys.
      feed_keys = [common.get_graph_element_name(key)
                   for key in self._feed_dict.keys()]
      curses_cli.register_tab_comp_context(["print_feed", "pf"], feed_keys)

  def _get_run_debug_urls(self):
    """Get the debug_urls value for the current run() call.

    Returns:
      debug_urls: (list of str) Debug URLs for the current run() call.
        Currently, the list consists of only one URL that is a file:// URL.
    """

    return ["file://" + self._dump_root]

  def _update_run_calls_state(self,
                              run_call_count,
                              fetches,
                              feed_dict,
                              is_callable_runner=False):
    """Update the internal state with regard to run() call history.

    Args:
      run_call_count: (int) Number of run() calls that have occurred.
      fetches: a node/tensor or a list of nodes/tensors that are the fetches of
        the run() call. This is the same as the fetches argument to the run()
        call.
      feed_dict: None or a dict. This is the feed_dict argument to the run()
        call.
      is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.
    """

    self._run_call_count = run_call_count
    self._feed_dict = feed_dict
    self._run_description = cli_shared.get_run_short_description(
        run_call_count,
        fetches,
        feed_dict,
        is_callable_runner=is_callable_runner)
    self._run_through_times -= 1

    self._run_info = cli_shared.get_run_start_intro(
        run_call_count,
        fetches,
        feed_dict,
        self._tensor_filters,
        is_callable_runner=is_callable_runner)
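

# -----------------------------------------------------------------------------
# Illustrative usage sketch (not part of this module's implementation): a user
# script would typically build a graph, wrap its `tf.Session` as shown below,
# and then drive debugging from the CLI, e.g. with `run -f has_inf_or_nan` to
# keep running until the pre-registered tensor filter passes. The tiny matmul
# graph here is a made-up example.
#
#   import tensorflow as tf
#   from tensorflow.python import debug as tf_debug
#
#   a = tf.constant([[1.0, 2.0]], name="a")
#   b = tf.constant([[3.0], [4.0]], name="b")
#   y = tf.matmul(a, b, name="y")
#
#   sess = tf.Session()
#   # Wrapping the session makes every subsequent run() call drop into the
#   # tfdbg curses (or readline) CLI before and after graph execution.
#   sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type="curses")
#   print(sess.run(y))
# -----------------------------------------------------------------------------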