1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Framework of debug wrapper sessions. 16 17A debug wrapper session is a wrapper around a TensorFlow Python Session. 18The wrapper preserves the Session interface, most importantly the run() method, 19while providing abilities to: 20a) Intercept a run() call to a wrapped session and insert debug tensor watches 21 according to externally-specified debug URLs. 22 23b) Release control to an external (i.e., non-Session) object before and after 24 the run() call, so that the external object can perform actions such as 25 launching a UI to let users inspect the intermediate tensors and partition 26 graphs from the run() call. 27 28c) (To be implemented in a future CL) Enter an instruction loop to let an 29 external object (e.g., remote client) launch run() and cont() calls 30 remotely. 31 32*** The lifetime of a debug wrapper session: *** 33 341) The wrapper session is created by calling the constructor with a 35 wrapped (normal) session as the argument: 36 wrapper = FooDebugWrapperSession(sess) 37 wherein FooDebugWrapperSession is a concrete subclass implementing the 38 abstract BaseDebugWrapperSession class below. 39 402) Near the end of the constructor call, the on_session_init() callback is 41 invoked, with a OnSessionInitRequest object as the argument. The object 42 carries the wrapped (normal) session object. 43 443) The callback handles the request and returns a OnSessionInitResponse 45 object with an action field, directing the wrapper session what to do next. 46 47If the action field in the OnSessionInitResponse is PROCEED, the constructor 48returns. Control is released back to the caller of the constructor, which can 49invoke run() method of wrapper session with the same syntax as a non-wrapped 50session, e.g.,: 51 wrapper.run(fetches, feed_dict=feeds, options=run_options) 52 53Below, A1 - A2 is the lifetime of a wrapper run() call if the action is 54PROCEED: 55 56A1) Right at the start of each run() call, the on_run_start() callback is 57 invoked, with an OnRunStartRequest object carrying information such as 58 the fetches, the feed dict, the run options and run metadata used in 59 this run call, along with a count of how many run calls has occurred 60 on this wrapper session. The callback then returns an OnRunStartResponse 61 object, of which the action field directs what the wrapper session 62 actually will do of the run() call. 63 64 If the action is DEBUG_RUN, a debugged (tensor-watched) run will ensue, 65 with the debug URLs supplied in the debug_urls field of the response. 66 These can be file:// or grpc:// URLs, for example. 67 68 If the action is NON_DEBUG_RUN, a non-debug (normal) run will ensue. 69 70A2) Right before the run() returns, the on_run_end() callback is invoked, 71 with an OnRunEndRequest object as the argument, which carries information 72 including the actual action performed in the wrapper run() call and the 73 run_metadata from the run() call. 74 75However, if the action field in OnSessionInitResponse is 76REMOTE_INSTR_LOOP, the constructor will automatically invoke an instruction loop 77that gives the control to a remote caller. 78 79In the remote instruction loop, the following steps will happen: 80 81B1) Callback on_instr_start() is invoked. The callback will return an 82 OnInstrStartResponse object with an action field which can order one of 83 the following actions: 84 i) a run() call with fetches, feeds and debug_urls specified. 85 ii) exit the instruction loop. 86 87B2) The wrapper session carries out the action specified above. 88 89B3) If still in the instruction loop, the wrapper session invokes the 90 on_instr_end() callback. After the on_instr_end() callback returns, jump 91 back to B1. 92 93TODO(cais): Implemented the instruction loop in B1 - B3. 94 95""" 96 97from __future__ import absolute_import 98from __future__ import division 99from __future__ import print_function 100 101import abc 102import re 103import threading 104 105import six 106 107from tensorflow.core.protobuf import config_pb2 108from tensorflow.python.client import session 109from tensorflow.python.debug.lib import debug_utils 110from tensorflow.python.framework import errors 111from tensorflow.python.framework import ops 112from tensorflow.python.platform import tf_logging 113from tensorflow.python.training import monitored_session 114from tensorflow.python.util import nest 115from tensorflow.python.util.compat import collections_abc 116 117 118# Helper function. 119def _check_type(obj, expected_types): 120 """Check if an object is of the expected type. 121 122 Args: 123 obj: The object being checked. 124 expected_types: (`type` or an iterable of `type`s) The expected `type`(s) 125 of obj. 126 127 Raises: 128 TypeError: If obj is not an instance of expected_type. 129 """ 130 if not isinstance(obj, expected_types): 131 raise TypeError("Expected type %s; got type %s" % 132 (expected_types, type(obj))) 133 134 135class OnSessionInitRequest(object): 136 """Request to an on-session-init callback. 137 138 This callback is invoked during the __init__ call to a debug-wrapper session. 139 """ 140 141 def __init__(self, sess): 142 """Constructor. 143 144 Args: 145 sess: A tensorflow Session object. 146 """ 147 148 _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession)) 149 self.session = sess 150 151 152class OnSessionInitAction(object): 153 """Enum-like values for possible action to take on session init.""" 154 155 # Proceed, without special actions, in the wrapper session initialization. 156 # What action the wrapper session performs next is determined by the caller 157 # of the wrapper session. E.g., it can call run(). 158 PROCEED = "proceed" 159 160 # Instead of letting the caller of the wrapper session determine what actions 161 # the wrapper session will perform next, enter a loop to receive instructions 162 # from a remote client. 163 # For example, TensorBoard visual debugger can use this action so that it can 164 # launch session.run() calls remotely. 165 REMOTE_INSTR_LOOP = "remote_instr_loop" 166 167 168class OnSessionInitResponse(object): 169 """Response from an on-session-init callback.""" 170 171 def __init__(self, action): 172 """Constructor. 173 174 Args: 175 action: (`OnSessionInitAction`) Debugger action to take on session init. 176 """ 177 _check_type(action, str) 178 self.action = action 179 180 181class OnRunStartRequest(object): 182 """Request to an on-run-start callback. 183 184 This callback is invoked during a run() call of the debug-wrapper 185 session, immediately after the run() call counter is incremented. 186 """ 187 188 def __init__(self, fetches, feed_dict, run_options, run_metadata, 189 run_call_count, is_callable_runner=False): 190 """Constructor of `OnRunStartRequest`. 191 192 Args: 193 fetches: Fetch targets of the run() call. 194 feed_dict: The feed dictionary to the run() call. 195 run_options: RunOptions input to the run() call. 196 run_metadata: RunMetadata input to the run() call. 197 The above four arguments are identical to the input arguments to the 198 run() method of a non-wrapped TensorFlow session. 199 run_call_count: 1-based count of how many run calls (including this one) 200 has been invoked. 201 is_callable_runner: (bool) whether a runner returned by 202 Session.make_callable is being run. 203 """ 204 self.fetches = fetches 205 self.feed_dict = feed_dict 206 self.run_options = run_options 207 self.run_metadata = run_metadata 208 self.run_call_count = run_call_count 209 self.is_callable_runner = is_callable_runner 210 211 212class OnRunStartAction(object): 213 """Enum-like values for possible action to take on start of a run() call.""" 214 215 # Run once with debug tensor-watching. 216 DEBUG_RUN = "debug_run" 217 218 # Run once with profiler. 219 PROFILE_RUN = "profile_run" 220 221 # Run without debug tensor-watching. 222 NON_DEBUG_RUN = "non_debug_run" 223 224 225 226class OnRunStartResponse(object): 227 """Request from an on-run-start callback. 228 229 The caller of the callback can use this response object to specify what 230 action the debug-wrapper session actually takes on the run() call. 231 """ 232 233 def __init__(self, 234 action, 235 debug_urls, 236 debug_ops="DebugIdentity", 237 node_name_regex_allowlist=None, 238 op_type_regex_allowlist=None, 239 tensor_dtype_regex_allowlist=None, 240 tolerate_debug_op_creation_failures=False): 241 """Constructor of `OnRunStartResponse`. 242 243 Args: 244 action: (`OnRunStartAction`) the action actually taken by the wrapped 245 session for the run() call. 246 debug_urls: (`list` of `str`) debug_urls used in watching the tensors 247 during the run() call. 248 debug_ops: (`str` or `list` of `str`) Debug op(s) to be used by the 249 debugger. 250 node_name_regex_allowlist: Regular-expression allowlist for node 251 name. 252 op_type_regex_allowlist: Regular-expression allowlist for op type. 253 tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor 254 dtype. 255 tolerate_debug_op_creation_failures: Whether debug op creation failures 256 are to be tolerated. 257 """ 258 259 _check_type(action, str) 260 self.action = action 261 262 _check_type(debug_urls, list) 263 self.debug_urls = debug_urls 264 265 self.debug_ops = debug_ops 266 267 self.node_name_regex_allowlist = node_name_regex_allowlist 268 self.op_type_regex_allowlist = op_type_regex_allowlist 269 self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist 270 self.tolerate_debug_op_creation_failures = ( 271 tolerate_debug_op_creation_failures) 272 273 274class OnRunEndRequest(object): 275 """Request to an on-run-end callback. 276 277 The callback is invoked immediately before the wrapped run() call ends. 278 """ 279 280 def __init__(self, 281 performed_action, 282 run_metadata=None, 283 client_graph_def=None, 284 tf_error=None): 285 """Constructor for `OnRunEndRequest`. 286 287 Args: 288 performed_action: (`OnRunStartAction`) Actually-performed action by the 289 debug-wrapper session. 290 run_metadata: run_metadata output from the run() call (if any). 291 client_graph_def: (GraphDef) GraphDef from the client side, i.e., from 292 the python front end of TensorFlow. Can be obtained with 293 session.graph.as_graph_def(). 294 tf_error: (errors.OpError subtypes) TensorFlow OpError that occurred 295 during the run (if any). 296 """ 297 298 _check_type(performed_action, str) 299 self.performed_action = performed_action 300 301 if run_metadata is not None: 302 _check_type(run_metadata, config_pb2.RunMetadata) 303 self.run_metadata = run_metadata 304 self.client_graph_def = client_graph_def 305 self.tf_error = tf_error 306 307 308class OnRunEndResponse(object): 309 """Response from an on-run-end callback.""" 310 311 def __init__(self): 312 313 # Currently only a placeholder. 314 pass 315 316 317@six.add_metaclass(abc.ABCMeta) 318class BaseDebugWrapperSession(session.SessionInterface): 319 """Base class of debug-wrapper session classes. 320 321 Concrete classes that inherit from this class need to implement the abstract 322 methods such as on_session_init, on_run_start and on_run_end. 323 """ 324 325 def __init__(self, sess, thread_name_filter=None, 326 pass_through_operrors=False): 327 """Constructor of `BaseDebugWrapperSession`. 328 329 Args: 330 sess: An (unwrapped) TensorFlow session instance. It should be a subtype 331 of `BaseSession` or `tf.MonitoredSession`. 332 thread_name_filter: Regular-expression filter (allowlist) for name(s) of 333 thread(s) on which the wrapper session will be active. This regular 334 expression is used in a start-anchored fashion on the thread name, i.e., 335 by applying the `match` method of the compiled pattern. The default 336 `None` means that the wrapper session will be active on all threads. 337 E.g., r"MainThread$", r"QueueRunnerThread.*". 338 pass_through_operrors: If True, all captured OpErrors will be 339 propagated. By default this captures all OpErrors. 340 341 Raises: 342 ValueError: On invalid `OnSessionInitAction` value. 343 NotImplementedError: If a non-DirectSession sess object is received. 344 """ 345 346 _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession)) 347 348 # The session being wrapped. 349 self._sess = sess 350 self._thread_name_filter_pattern = (re.compile(thread_name_filter) 351 if thread_name_filter else None) 352 # TODO(cais/kstevens): Unittest this pass through feature. 353 self._pass_through_operrors = pass_through_operrors 354 355 # Keeps track of number of run calls that have been performed on this 356 # debug-wrapper session. The count can be used for purposes such as 357 # displaying the state of the Session in a UI and determining a run 358 # number-dependent debug URL. 359 self._run_call_count = 0 360 361 # Invoke on-session-init callback. 362 response = self.on_session_init(OnSessionInitRequest(self._sess)) 363 _check_type(response, OnSessionInitResponse) 364 365 if response.action == OnSessionInitAction.PROCEED: 366 pass 367 elif response.action == OnSessionInitAction.REMOTE_INSTR_LOOP: 368 # TODO(cais): Implement REMOTE_INSTR_LOOP 369 raise NotImplementedError( 370 "OnSessionInitAction REMOTE_INSTR_LOOP has not been " 371 "implemented.") 372 else: 373 raise ValueError( 374 "Invalid OnSessionInitAction value: %s" % response.action) 375 376 self._default_session_context_manager = None 377 378 # A cache for callables created from CallableOptions. 379 self._cached_callables_from_options = {} 380 381 @property 382 def graph(self): 383 return self._sess.graph 384 385 @property 386 def graph_def(self): 387 return self._sess.graph_def 388 389 @property 390 def sess_str(self): 391 return self._sess.sess_str 392 393 @property 394 def session(self): 395 return self._sess 396 397 def run(self, 398 fetches, 399 feed_dict=None, 400 options=None, 401 run_metadata=None, 402 callable_runner=None, 403 callable_runner_args=None, 404 callable_options=None): 405 """Wrapper around Session.run() that inserts tensor watch options. 406 407 Args: 408 fetches: Same as the `fetches` arg to regular `Session.run()`. 409 feed_dict: Same as the `feed_dict` arg to regular `Session.run()`. 410 options: Same as the `options` arg to regular `Session.run()`. 411 run_metadata: Same as the `run_metadata` arg to regular `Session.run()`. 412 callable_runner: A `callable` returned by `Session.make_callable()`. 413 If not `None`, `fetches` and `feed_dict` must both be `None`. 414 Mutually exclusive with `callable_options`. 415 callable_runner_args: An optional list of arguments to `callable_runner` 416 or for `callable_options`. 417 callable_options: An instance of `config_pb2.CallableOptions`, to be 418 used with `Session._make_callable_from_options()`. Mutually exclusive 419 with `callable_runner`. 420 421 Returns: 422 Simply forwards the output of the wrapped `Session.run()` call. 423 424 Raises: 425 ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner` 426 is not `None` and either or both of `fetches` and `feed_dict` is `None`. 427 """ 428 if callable_runner and callable_options: 429 raise ValueError( 430 "callable_runner and callable_options are mutually exclusive, but " 431 "are both specified in this call to BaseDebugWrapperSession.run().") 432 433 if callable_runner and (fetches or feed_dict): 434 raise ValueError( 435 "callable_runner and fetches/feed_dict are mutually exclusive, " 436 "but are used simultaneously.") 437 elif callable_options and (fetches or feed_dict): 438 raise ValueError( 439 "callable_options and fetches/feed_dict are mutually exclusive, " 440 "but are used simultaneously.") 441 442 self.increment_run_call_count() 443 444 def is_empty(x): 445 """Check whether a possibly nested structure is empty.""" 446 if not nest.is_nested(x): 447 return False 448 if isinstance(x, collections_abc.Mapping): 449 return is_empty(list(x.values())) 450 for item in x: 451 if not is_empty(item): 452 return False 453 return True 454 455 empty_fetches = is_empty(fetches) 456 if empty_fetches: 457 tf_logging.info( 458 "Due to empty fetches, tfdbg Session wrapper is letting a " 459 "Session.run pass through without any debugging actions.") 460 if self._is_disabled_thread() or empty_fetches: 461 if callable_runner: 462 return callable_runner(*callable_runner_args) 463 elif callable_options: 464 # pylint:disable=protected-access 465 return self._sess._make_callable_from_options( 466 callable_options)(*callable_runner_args) 467 # pylint:enable=protected-access 468 else: 469 return self._sess.run(fetches, 470 feed_dict=feed_dict, 471 options=options, 472 run_metadata=run_metadata) 473 474 # Invoke on-run-start callback and obtain response. 475 run_start_resp = self.on_run_start( 476 OnRunStartRequest(fetches, feed_dict, options, run_metadata, 477 self._run_call_count, 478 is_callable_runner=bool(callable_runner))) 479 _check_type(run_start_resp, OnRunStartResponse) 480 481 if run_start_resp.action == OnRunStartAction.DEBUG_RUN: 482 retvals, run_end_req = self._run_with_debugging( 483 run_start_resp, fetches, feed_dict, options, run_metadata, 484 callable_runner, callable_runner_args, callable_options) 485 elif run_start_resp.action == OnRunStartAction.PROFILE_RUN: 486 retvals, run_end_req = self._run_with_profiling( 487 run_start_resp, fetches, feed_dict, options, run_metadata, 488 callable_runner, callable_runner_args, callable_options) 489 elif run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN: 490 # Invoke run() method of the wrapped session. 491 if callable_runner: 492 retvals = callable_runner(*callable_runner_args) 493 elif callable_options: 494 # pylint:disable=protected-access 495 callable_object = self._sess._make_callable_from_options( 496 callable_options) 497 # pylint:enable=protected-access 498 retvals = callable_object(*callable_runner_args) 499 else: 500 retvals = self._sess.run( 501 fetches, 502 feed_dict=feed_dict, 503 options=options, 504 run_metadata=run_metadata) 505 506 # Prepare arg for the on-run-end callback. 507 run_end_req = OnRunEndRequest(run_start_resp.action) 508 else: 509 raise ValueError( 510 "Invalid OnRunStartAction value: %s" % run_start_resp.action) 511 512 # Invoke on-run-end callback and obtain response. 513 run_end_resp = self.on_run_end(run_end_req) 514 _check_type(run_end_resp, OnRunEndResponse) 515 # Currently run_end_resp is only a placeholder. No action is taken on it. 516 517 return retvals 518 519 def _run_with_debugging(self, 520 run_start_resp, 521 fetches, 522 feed_dict, 523 options, 524 run_metadata, 525 callable_runner, 526 callable_runner_args, 527 callable_options): 528 """Perform a session.run() or callable with debugging.""" 529 # Decorate RunOption to fill in debugger tensor watch specifications. 530 decorated_run_options = None 531 if callable_options: 532 callable_options_id = id(callable_options) 533 if callable_options_id not in self._cached_callables_from_options: 534 # Make a copy of callable_options to avoid mutating it. 535 new_callable_options = config_pb2.CallableOptions() 536 new_callable_options.CopyFrom(callable_options) 537 decorated_run_options = new_callable_options.run_options 538 else: 539 decorated_run_options = options or config_pb2.RunOptions() 540 541 run_metadata = run_metadata or config_pb2.RunMetadata() 542 543 if decorated_run_options: 544 self._decorate_run_options_for_debug( 545 decorated_run_options, 546 run_start_resp.debug_urls, 547 debug_ops=run_start_resp.debug_ops, 548 node_name_regex_allowlist=(run_start_resp.node_name_regex_allowlist), 549 op_type_regex_allowlist=run_start_resp.op_type_regex_allowlist, 550 tensor_dtype_regex_allowlist=( 551 run_start_resp.tensor_dtype_regex_allowlist), 552 tolerate_debug_op_creation_failures=( 553 run_start_resp.tolerate_debug_op_creation_failures)) 554 555 # Invoke the run() method of the wrapped Session. Catch any TensorFlow 556 # runtime errors. 557 tf_error = None 558 try: 559 if callable_runner: 560 retvals = callable_runner(*callable_runner_args, 561 options=decorated_run_options, 562 run_metadata=run_metadata) 563 elif callable_options: 564 # pylint:disable=protected-access 565 if callable_options_id in self._cached_callables_from_options: 566 callable_object = self._cached_callables_from_options[ 567 callable_options_id] 568 else: 569 callable_object = self._sess._make_callable_from_options( 570 new_callable_options) 571 self._cached_callables_from_options[ 572 callable_options_id] = callable_object 573 # pylint:enable=protected-access 574 retvals = callable_object( 575 *callable_runner_args, run_metadata=run_metadata) 576 else: 577 retvals = self._sess.run(fetches, 578 feed_dict=feed_dict, 579 options=decorated_run_options, 580 run_metadata=run_metadata) 581 except errors.OpError as op_error: 582 if self._pass_through_operrors: 583 raise op_error 584 tf_error = op_error 585 retvals = op_error 586 587 return retvals, OnRunEndRequest( 588 run_start_resp.action, 589 run_metadata=run_metadata, 590 client_graph_def=self._sess.graph.as_graph_def(), 591 tf_error=tf_error) 592 593 def _run_with_profiling(self, 594 run_start_resp, 595 fetches, 596 feed_dict, 597 options, 598 run_metadata, 599 callable_runner, 600 callable_runner_args, 601 callable_options): 602 """Perform a session.run() or callable with profiling.""" 603 # Decorate RunOption to fill in debugger tensor watch specifications. 604 decorated_run_options = None 605 if callable_options: 606 callable_options_id = id(callable_options) 607 if callable_options_id not in self._cached_callables_from_options: 608 # Make a copy of callable_options to avoid mutating it. 609 new_callable_options = config_pb2.CallableOptions() 610 new_callable_options.CopyFrom(callable_options) 611 decorated_run_options = new_callable_options.run_options 612 else: 613 decorated_run_options = options or config_pb2.RunOptions() 614 self._decorate_run_options_for_profile(decorated_run_options) 615 616 run_metadata = run_metadata or config_pb2.RunMetadata() 617 if callable_runner: 618 retvals = callable_runner(*callable_runner_args, 619 options=decorated_run_options, 620 run_metadata=run_metadata) 621 elif callable_options: 622 # pylint:disable=protected-access 623 callable_object = self._sess._make_callable_from_options( 624 new_callable_options) 625 # pylint:enable=protected-access 626 retvals = callable_object( 627 *callable_runner_args, run_metadata=run_metadata) 628 else: 629 retvals = self._sess.run(fetches, 630 feed_dict=feed_dict, 631 options=decorated_run_options, 632 run_metadata=run_metadata) 633 return retvals, OnRunEndRequest( 634 run_start_resp.action, 635 run_metadata=run_metadata, 636 client_graph_def=self._sess.graph.as_graph_def()) 637 638 def _is_disabled_thread(self): 639 thread_name = threading.current_thread().name or "" 640 return (self._thread_name_filter_pattern and 641 not self._thread_name_filter_pattern.match(thread_name)) 642 643 def run_step_fn(self, step_fn): 644 return step_fn( 645 monitored_session.MonitoredSession.StepContext(self._sess, self.run)) 646 647 def partial_run_setup(self, fetches, feeds=None): 648 """Sets up the feeds and fetches for partial runs in the session.""" 649 raise NotImplementedError( 650 "partial_run_setup is not implemented for debug-wrapper sessions.") 651 652 def partial_run(self, handle, fetches, feed_dict=None): 653 raise NotImplementedError( 654 "partial_run is not implemented for debug-wrapper sessions.") 655 656 def list_devices(self, *args, **kwargs): 657 return self._sess.list_devices(*args, **kwargs) 658 659 def reset(self, *args, **kwargs): 660 return self._sess.reset(*args, **kwargs) 661 662 def make_callable(self, 663 fetches, 664 feed_list=None, 665 accept_options=False): 666 runner = self._sess.make_callable( 667 fetches, feed_list=feed_list, accept_options=True) 668 def wrapped_runner(*runner_args, **kwargs): 669 return self.run(None, 670 feed_dict=None, 671 options=kwargs.get("options", None), 672 run_metadata=kwargs.get("run_metadata", None), 673 callable_runner=runner, 674 callable_runner_args=runner_args) 675 return wrapped_runner 676 677 def _make_callable_from_options(self, callable_options): 678 def wrapped_runner(*feed_values, **kwargs): 679 return self.run(None, 680 run_metadata=kwargs.get("run_metadata", None), 681 callable_options=callable_options, 682 callable_runner_args=feed_values) 683 return wrapped_runner 684 685 @property 686 def run_call_count(self): 687 return self._run_call_count 688 689 def increment_run_call_count(self): 690 self._run_call_count += 1 691 692 def _is_disk_usage_reset_each_run(self): 693 """Indicates whether disk usage is reset after each Session.run. 694 695 Subclasses that clean up the disk usage after every run should 696 override this protected method. 697 698 Returns: 699 (`bool`) Whether the disk usage amount is reset to zero after 700 each Session.run. 701 """ 702 return False 703 704 def _decorate_run_options_for_debug( 705 self, 706 run_options, 707 debug_urls, 708 debug_ops="DebugIdentity", 709 node_name_regex_allowlist=None, 710 op_type_regex_allowlist=None, 711 tensor_dtype_regex_allowlist=None, 712 tolerate_debug_op_creation_failures=False): 713 """Modify a RunOptions object for debug tensor watching. 714 715 Specifies request for outputting partition graphs. Adds 716 debug_tensor_watch_opts with proper debug URLs. 717 718 Args: 719 run_options: (RunOptions) the modified RunOptions object. 720 debug_urls: (list of str) debug URLs to be entered in run_options. 721 debug_tensor_watch_opts. 722 debug_ops: (str or list of str) debug op(s) to be used by the debugger. 723 node_name_regex_allowlist: Regular-expression allowlist for node 724 name. 725 op_type_regex_allowlist: Regular-expression allowlist for op type. 726 tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor 727 dtype. 728 tolerate_debug_op_creation_failures: Whether debug op creation failures 729 are to be tolerated. 730 """ 731 732 run_options.output_partition_graphs = True 733 debug_utils.watch_graph( 734 run_options, 735 self._sess.graph, 736 debug_urls=debug_urls, 737 debug_ops=debug_ops, 738 node_name_regex_allowlist=node_name_regex_allowlist, 739 op_type_regex_allowlist=op_type_regex_allowlist, 740 tensor_dtype_regex_allowlist=tensor_dtype_regex_allowlist, 741 tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures, 742 reset_disk_byte_usage=(self._run_call_count == 1 or 743 self._is_disk_usage_reset_each_run())) 744 745 def _decorate_run_options_for_profile(self, run_options): 746 """Modify a RunOptions object for profiling TensorFlow graph execution. 747 748 Args: 749 run_options: (RunOptions) the modified RunOptions object. 750 """ 751 752 run_options.trace_level = config_pb2.RunOptions.FULL_TRACE 753 754 @abc.abstractmethod 755 def on_session_init(self, request): 756 """Callback invoked during construction of the debug-wrapper session. 757 758 This is a blocking callback. 759 The invocation happens right before the constructor ends. 760 761 Args: 762 request: (`OnSessionInitRequest`) callback request carrying information 763 such as the session being wrapped. 764 765 Returns: 766 An instance of `OnSessionInitResponse`. 767 """ 768 769 @abc.abstractmethod 770 def on_run_start(self, request): 771 """Callback invoked on run() calls to the debug-wrapper session. 772 773 This is a blocking callback. 774 The invocation happens after the wrapper's run() call is entered, 775 after an increment of run call counter. 776 777 Args: 778 request: (`OnRunStartRequest`) callback request object carrying 779 information about the run call such as the fetches, feed dict, run 780 options, run metadata, and how many `run()` calls to this wrapper 781 session have occurred. 782 783 Returns: 784 An instance of `OnRunStartResponse`, carrying information to 785 debug URLs used to watch the tensors. 786 """ 787 788 @abc.abstractmethod 789 def on_run_end(self, request): 790 """Callback invoked on run() calls to the debug-wrapper session. 791 792 This is a blocking callback. 793 The invocation happens right before the wrapper exits its run() call. 794 795 Args: 796 request: (`OnRunEndRequest`) callback request object carrying information 797 such as the actual action performed by the session wrapper for the 798 run() call. 799 800 Returns: 801 An instance of `OnRunStartResponse`. 802 """ 803 804 def as_default(self): 805 return ops.default_session(self) 806 807 def __enter__(self): 808 if self._default_session_context_manager is None: 809 self._default_session_context_manager = self.as_default() 810 return self._default_session_context_manager.__enter__() 811 812 def __exit__(self, exec_type, exec_value, exec_tb): 813 self._default_session_context_manager.__exit__( 814 exec_type, exec_value, exec_tb) 815 816 def __del__(self): 817 if hasattr(self._sess, "__del__"): 818 self._sess.__del__() 819 820 def close(self): 821 self._sess.close() 822 823 # TODO(cais): Add _node_name_regex_allowlist and 824 # _node_op_type_regex_allowlist. 825 826 def should_stop(self): 827 if hasattr(self._sess, "should_stop"): 828 return self._sess.should_stop() 829 else: 830 raise ValueError( 831 "The wrapped session %r does not have a method called 'should_stop'. " 832 "Do you intend to wrap a tf.MonitoredSession instead?" % self._sess) 833 834 835class WatchOptions(object): 836 """Type for return values of watch_fn.""" 837 838 def __init__(self, 839 debug_ops=None, 840 node_name_regex_allowlist=None, 841 op_type_regex_allowlist=None, 842 tensor_dtype_regex_allowlist=None, 843 tolerate_debug_op_creation_failures=False): 844 """Constructor of WatchOptions: Debug watch options. 845 846 Used as return values of `watch_fn`s. 847 848 Args: 849 debug_ops: (`str` or `list of str`) Debug ops to be used. 850 node_name_regex_allowlist: Regular-expression allowlist for node_name, 851 e.g., `"(weight_[0-9]+|bias_.*)"` 852 op_type_regex_allowlist: Regular-expression allowlist for the op type of 853 nodes, e.g., `"(Variable|Add)"`. 854 If both `node_name_regex_allowlist` and `op_type_regex_allowlist` 855 are set, the two filtering operations will occur in a logical `AND` 856 relation. In other words, a node will be included if and only if it 857 hits both allowlists. 858 tensor_dtype_regex_allowlist: Regular-expression allowlist for Tensor 859 data type, e.g., `"^int.*"`. 860 This allowlist operates in logical `AND` relations to the two allowlists 861 above. 862 tolerate_debug_op_creation_failures: (`bool`) whether debug op creation 863 failures (e.g., due to dtype incompatibility) are to be tolerated by not 864 throwing exceptions. 865 """ 866 if debug_ops: 867 self.debug_ops = debug_ops 868 else: 869 self.debug_ops = ["DebugIdentity"] 870 self.node_name_regex_allowlist = node_name_regex_allowlist 871 self.op_type_regex_allowlist = op_type_regex_allowlist 872 self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist 873 self.tolerate_debug_op_creation_failures = ( 874 tolerate_debug_op_creation_failures) 875 876 def __repr__(self): 877 return ("WatchOptions(debug_ops=%r, node_name_regex_allowlist=%r, " 878 "op_type_regex_allowlist=%r, tensor_dtype_regex_allowlist=%r, " 879 "tolerate_debug_op_creation_failures=%r)" % 880 (self.debug_ops, self.node_name_regex_allowlist, 881 self.op_type_regex_allowlist, self.tensor_dtype_regex_allowlist, 882 self.tolerate_debug_op_creation_failures)) 883 884 885class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession): 886 """Base class for non-interactive (i.e., non-CLI) debug wrapper sessions.""" 887 888 def __init__(self, sess, watch_fn=None, thread_name_filter=None, 889 pass_through_operrors=False): 890 """Constructor of NonInteractiveDebugWrapperSession. 891 892 Args: 893 sess: The TensorFlow `Session` object being wrapped. 894 watch_fn: (`Callable`) A Callable that maps the fetches and feeds of a 895 debugged `Session.run()` call to `WatchOptions.` 896 * Args: 897 * `fetches`: the fetches to the `Session.run()` call. 898 * `feeds`: the feeds to the `Session.run()` call. 899 900 * Returns: 901 (`tf_debug.WatchOptions`) An object containing debug options including 902 the debug ops to use, the node names, op types and/or tensor data 903 types to watch, etc. See the documentation of `tf_debug.WatchOptions` 904 for more details. 905 thread_name_filter: Regular-expression white list for threads on which the 906 wrapper session will be active. See doc of `BaseDebugWrapperSession` for 907 more details. 908 pass_through_operrors: If true, all captured OpErrors will be 909 propagated. By default this captures all OpErrors. 910 Raises: 911 TypeError: If a non-None `watch_fn` is specified and it is not callable. 912 """ 913 914 BaseDebugWrapperSession.__init__( 915 self, sess, thread_name_filter=thread_name_filter, 916 pass_through_operrors=pass_through_operrors) 917 918 self._watch_fn = None 919 if watch_fn is not None: 920 if not callable(watch_fn): 921 raise TypeError("watch_fn is not callable") 922 self._watch_fn = watch_fn 923 924 def on_session_init(self, request): 925 """See doc of BaseDebugWrapperSession.on_run_start.""" 926 927 return OnSessionInitResponse(OnSessionInitAction.PROCEED) 928 929 @abc.abstractmethod 930 def prepare_run_debug_urls(self, fetches, feed_dict): 931 """Abstract method to be implemented by concrete subclasses. 932 933 This method prepares the run-specific debug URL(s). 934 935 Args: 936 fetches: Same as the `fetches` argument to `Session.run()` 937 feed_dict: Same as the `feed_dict` argument to `Session.run()` 938 939 Returns: 940 debug_urls: (`str` or `list` of `str`) Debug URLs to be used in 941 this `Session.run()` call. 942 """ 943 944 def on_run_start(self, request): 945 """See doc of BaseDebugWrapperSession.on_run_start.""" 946 947 debug_urls, watch_opts = self._prepare_run_watch_config( 948 request.fetches, request.feed_dict) 949 950 return OnRunStartResponse( 951 OnRunStartAction.DEBUG_RUN, 952 debug_urls, 953 debug_ops=watch_opts.debug_ops, 954 node_name_regex_allowlist=watch_opts.node_name_regex_allowlist, 955 op_type_regex_allowlist=watch_opts.op_type_regex_allowlist, 956 tensor_dtype_regex_allowlist=watch_opts.tensor_dtype_regex_allowlist, 957 tolerate_debug_op_creation_failures=( 958 watch_opts.tolerate_debug_op_creation_failures)) 959 960 def _prepare_run_watch_config(self, fetches, feed_dict): 961 """Get the debug_urls, and node/op allowlists for the current run() call. 962 963 Args: 964 fetches: Same as the `fetches` argument to `Session.run()`. 965 feed_dict: Same as the `feed_dict argument` to `Session.run()`. 966 967 Returns: 968 debug_urls: (str or list of str) Debug URLs for the current run() call. 969 Currently, the list consists of only one URL that is a file:// URL. 970 watch_options: (WatchOptions) The return value of a watch_fn, containing 971 options including debug_ops, and allowlists. 972 """ 973 974 debug_urls = self.prepare_run_debug_urls(fetches, feed_dict) 975 if self._watch_fn is None: 976 watch_options = WatchOptions() 977 else: 978 watch_options = self._watch_fn(fetches, feed_dict) 979 if isinstance(watch_options, tuple): 980 # For legacy return type (tuples). 981 watch_options = WatchOptions(*watch_options) 982 983 return debug_urls, watch_options 984 985 def on_run_end(self, request): 986 """See doc of BaseDebugWrapperSession.on_run_end.""" 987 988 return OnRunEndResponse() 989