• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Framework of debug wrapper sessions.
16
17A debug wrapper session is a wrapper around a TensorFlow Python Session.
18The wrapper preserves the Session interface, most importantly the run() method,
19while providing abilities to:
20a) Intercept a run() call to a wrapped session and insert debug tensor watches
21   according to externally-specified debug URLs.
22
23b) Release control to an external (i.e., non-Session) object before and after
24   the run() call, so that the external object can perform actions such as
25   launching a UI to let users inspect the intermediate tensors and partition
26   graphs from the run() call.
27
28c) (To be implemented in a future CL) Enter an instruction loop to let an
29   external object (e.g., remote client) launch run() and cont() calls
30   remotely.
31
32*** The lifetime of a debug wrapper session: ***
33
341) The wrapper session is created by calling the constructor with a
35   wrapped (normal) session as the argument:
36     wrapper = FooDebugWrapperSession(sess)
37   wherein FooDebugWrapperSession is a concrete subclass implementing the
38   abstract BaseDebugWrapperSession class below.
39
402) Near the end of the constructor call, the on_session_init() callback is
41   invoked, with a OnSessionInitRequest object as the argument. The object
42   carries the wrapped (normal) session object.
43
443) The callback handles the request and returns a OnSessionInitResponse
45   object with an action field, directing the wrapper session what to do next.
46
47If the action field in the OnSessionInitResponse is PROCEED, the constructor
48returns. Control is released back to the caller of the constructor, which can
49invoke run() method of wrapper session with the same syntax as a non-wrapped
50session, e.g.,:
51  wrapper.run(fetches, feed_dict=feeds, options=run_options)
52
53Below, A1 - A2 is the lifetime of a wrapper run() call if the action is
54PROCEED:
55
56A1) Right at the start of each run() call, the on_run_start() callback is
57    invoked, with an OnRunStartRequest object carrying information such as
58    the fetches, the feed dict, the run options and run metadata used in
59    this run call, along with a count of how many run calls has occurred
60    on this wrapper session. The callback then returns an OnRunStartResponse
61    object, of which the action field directs what the wrapper session
62    actually will do of the run() call.
63
64    If the action is DEBUG_RUN, a debugged (tensor-watched) run will ensue,
65    with the debug URLs supplied in the debug_urls field of the response.
66    These can be file:// or grpc:// URLs, for example.
67
68    If the action is NON_DEBUG_RUN, a non-debug (normal) run will ensue.
69
70A2) Right before the run() returns, the on_run_end() callback is invoked,
71    with an OnRunEndRequest object as the argument, which carries information
72    including the actual action performed in the wrapper run() call and the
73    run_metadata from the run() call.
74
75However, if the action field in OnSessionInitResponse is
76REMOTE_INSTR_LOOP, the constructor will automatically invoke an instruction loop
77that gives the control to a remote caller.
78
79In the remote instruction loop, the following steps will happen:
80
81B1) Callback on_instr_start() is invoked. The callback will return an
82    OnInstrStartResponse object with an action field which can order one of
83    the following actions:
84        i) a run() call with fetches, feeds and debug_urls specified.
85       ii) exit the instruction loop.
86
87B2) The wrapper session carries out the action specified above.
88
89B3) If still in the instruction loop, the wrapper session invokes the
90    on_instr_end() callback. After the on_instr_end() callback returns, jump
91    back to B1.
92
93TODO(cais): Implemented the instruction loop in B1 - B3.
94
95"""
96
97from __future__ import absolute_import
98from __future__ import division
99from __future__ import print_function
100
101import abc
102import re
103import threading
104
105import six
106
107from tensorflow.core.protobuf import config_pb2
108from tensorflow.python.client import session
109from tensorflow.python.debug.lib import debug_utils
110from tensorflow.python.framework import errors
111from tensorflow.python.framework import ops
112from tensorflow.python.platform import tf_logging
113from tensorflow.python.training import monitored_session
114from tensorflow.python.util import nest
115from tensorflow.python.util.compat import collections_abc
116
117
118# Helper function.
119def _check_type(obj, expected_types):
120  """Check if an object is of the expected type.
121
122  Args:
123    obj: The object being checked.
124    expected_types: (`type` or an iterable of `type`s) The expected `type`(s)
125      of obj.
126
127  Raises:
128      TypeError: If obj is not an instance of expected_type.
129  """
130  if not isinstance(obj, expected_types):
131    raise TypeError("Expected type %s; got type %s" %
132                    (expected_types, type(obj)))
133
134
135class OnSessionInitRequest(object):
136  """Request to an on-session-init callback.
137
138  This callback is invoked during the __init__ call to a debug-wrapper session.
139  """
140
141  def __init__(self, sess):
142    """Constructor.
143
144    Args:
145      sess: A tensorflow Session object.
146    """
147
148    _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
149    self.session = sess
150
151
152class OnSessionInitAction(object):
153  """Enum-like values for possible action to take on session init."""
154
155  # Proceed, without special actions, in the wrapper session initialization.
156  # What action the wrapper session performs next is determined by the caller
157  # of the wrapper session. E.g., it can call run().
158  PROCEED = "proceed"
159
160  # Instead of letting the caller of the wrapper session determine what actions
161  # the wrapper session will perform next, enter a loop to receive instructions
162  # from a remote client.
163  # For example, TensorBoard visual debugger can use this action so that it can
164  # launch session.run() calls remotely.
165  REMOTE_INSTR_LOOP = "remote_instr_loop"
166
167
168class OnSessionInitResponse(object):
169  """Response from an on-session-init callback."""
170
171  def __init__(self, action):
172    """Constructor.
173
174    Args:
175      action: (`OnSessionInitAction`) Debugger action to take on session init.
176    """
177    _check_type(action, str)
178    self.action = action
179
180
181class OnRunStartRequest(object):
182  """Request to an on-run-start callback.
183
184  This callback is invoked during a run() call of the debug-wrapper
185  session, immediately after the run() call counter is incremented.
186  """
187
188  def __init__(self, fetches, feed_dict, run_options, run_metadata,
189               run_call_count, is_callable_runner=False):
190    """Constructor of `OnRunStartRequest`.
191
192    Args:
193      fetches: Fetch targets of the run() call.
194      feed_dict: The feed dictionary to the run() call.
195      run_options: RunOptions input to the run() call.
196      run_metadata: RunMetadata input to the run() call.
197        The above four arguments are identical to the input arguments to the
198        run() method of a non-wrapped TensorFlow session.
199      run_call_count: 1-based count of how many run calls (including this one)
200        has been invoked.
201      is_callable_runner: (bool) whether a runner returned by
202        Session.make_callable is being run.
203    """
204    self.fetches = fetches
205    self.feed_dict = feed_dict
206    self.run_options = run_options
207    self.run_metadata = run_metadata
208    self.run_call_count = run_call_count
209    self.is_callable_runner = is_callable_runner
210
211
212class OnRunStartAction(object):
213  """Enum-like values for possible action to take on start of a run() call."""
214
215  # Run once with debug tensor-watching.
216  DEBUG_RUN = "debug_run"
217
218  # Run once with profiler.
219  PROFILE_RUN = "profile_run"
220
221  # Run without debug tensor-watching.
222  NON_DEBUG_RUN = "non_debug_run"
223
224
225
226class OnRunStartResponse(object):
227  """Request from an on-run-start callback.
228
229  The caller of the callback can use this response object to specify what
230  action the debug-wrapper session actually takes on the run() call.
231  """
232
233  def __init__(self,
234               action,
235               debug_urls,
236               debug_ops="DebugIdentity",
237               node_name_regex_allowlist=None,
238               op_type_regex_allowlist=None,
239               tensor_dtype_regex_allowlist=None,
240               tolerate_debug_op_creation_failures=False):
241    """Constructor of `OnRunStartResponse`.
242
243    Args:
244      action: (`OnRunStartAction`) the action actually taken by the wrapped
245        session for the run() call.
246      debug_urls: (`list` of `str`) debug_urls used in watching the tensors
247        during the run() call.
248      debug_ops: (`str` or `list` of `str`) Debug op(s) to be used by the
249        debugger.
250      node_name_regex_allowlist: Regular-expression allowlist for node
251        name.
252      op_type_regex_allowlist: Regular-expression allowlist for op type.
253      tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor
254        dtype.
255      tolerate_debug_op_creation_failures: Whether debug op creation failures
256        are to be tolerated.
257    """
258
259    _check_type(action, str)
260    self.action = action
261
262    _check_type(debug_urls, list)
263    self.debug_urls = debug_urls
264
265    self.debug_ops = debug_ops
266
267    self.node_name_regex_allowlist = node_name_regex_allowlist
268    self.op_type_regex_allowlist = op_type_regex_allowlist
269    self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist
270    self.tolerate_debug_op_creation_failures = (
271        tolerate_debug_op_creation_failures)
272
273
274class OnRunEndRequest(object):
275  """Request to an on-run-end callback.
276
277  The callback is invoked immediately before the wrapped run() call ends.
278  """
279
280  def __init__(self,
281               performed_action,
282               run_metadata=None,
283               client_graph_def=None,
284               tf_error=None):
285    """Constructor for `OnRunEndRequest`.
286
287    Args:
288      performed_action: (`OnRunStartAction`) Actually-performed action by the
289        debug-wrapper session.
290      run_metadata: run_metadata output from the run() call (if any).
291      client_graph_def: (GraphDef) GraphDef from the client side, i.e., from
292        the python front end of TensorFlow. Can be obtained with
293        session.graph.as_graph_def().
294      tf_error: (errors.OpError subtypes) TensorFlow OpError that occurred
295        during the run (if any).
296    """
297
298    _check_type(performed_action, str)
299    self.performed_action = performed_action
300
301    if run_metadata is not None:
302      _check_type(run_metadata, config_pb2.RunMetadata)
303    self.run_metadata = run_metadata
304    self.client_graph_def = client_graph_def
305    self.tf_error = tf_error
306
307
308class OnRunEndResponse(object):
309  """Response from an on-run-end callback."""
310
311  def __init__(self):
312
313    # Currently only a placeholder.
314    pass
315
316
317@six.add_metaclass(abc.ABCMeta)
318class BaseDebugWrapperSession(session.SessionInterface):
319  """Base class of debug-wrapper session classes.
320
321  Concrete classes that inherit from this class need to implement the abstract
322  methods such as on_session_init, on_run_start and on_run_end.
323  """
324
325  def __init__(self, sess, thread_name_filter=None,
326               pass_through_operrors=False):
327    """Constructor of `BaseDebugWrapperSession`.
328
329    Args:
330      sess: An (unwrapped) TensorFlow session instance. It should be a subtype
331        of `BaseSession` or `tf.MonitoredSession`.
332      thread_name_filter: Regular-expression filter (allowlist) for name(s) of
333        thread(s) on which the wrapper session will be active. This regular
334        expression is used in a start-anchored fashion on the thread name, i.e.,
335        by applying the `match` method of the compiled pattern. The default
336        `None` means that the wrapper session will be active on all threads.
337        E.g., r"MainThread$", r"QueueRunnerThread.*".
338      pass_through_operrors: If True, all captured OpErrors will be
339        propagated.  By default this captures all OpErrors.
340
341    Raises:
342      ValueError: On invalid `OnSessionInitAction` value.
343      NotImplementedError: If a non-DirectSession sess object is received.
344    """
345
346    _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
347
348    # The session being wrapped.
349    self._sess = sess
350    self._thread_name_filter_pattern = (re.compile(thread_name_filter)
351                                        if thread_name_filter else None)
352    # TODO(cais/kstevens): Unittest this pass through feature.
353    self._pass_through_operrors = pass_through_operrors
354
355    # Keeps track of number of run calls that have been performed on this
356    # debug-wrapper session. The count can be used for purposes such as
357    # displaying the state of the Session in a UI and determining a run
358    # number-dependent debug URL.
359    self._run_call_count = 0
360
361    # Invoke on-session-init callback.
362    response = self.on_session_init(OnSessionInitRequest(self._sess))
363    _check_type(response, OnSessionInitResponse)
364
365    if response.action == OnSessionInitAction.PROCEED:
366      pass
367    elif response.action == OnSessionInitAction.REMOTE_INSTR_LOOP:
368      # TODO(cais): Implement REMOTE_INSTR_LOOP
369      raise NotImplementedError(
370          "OnSessionInitAction REMOTE_INSTR_LOOP has not been "
371          "implemented.")
372    else:
373      raise ValueError(
374          "Invalid OnSessionInitAction value: %s" % response.action)
375
376    self._default_session_context_manager = None
377
378    # A cache for callables created from CallableOptions.
379    self._cached_callables_from_options = {}
380
381  @property
382  def graph(self):
383    return self._sess.graph
384
385  @property
386  def graph_def(self):
387    return self._sess.graph_def
388
389  @property
390  def sess_str(self):
391    return self._sess.sess_str
392
393  @property
394  def session(self):
395    return self._sess
396
397  def run(self,
398          fetches,
399          feed_dict=None,
400          options=None,
401          run_metadata=None,
402          callable_runner=None,
403          callable_runner_args=None,
404          callable_options=None):
405    """Wrapper around Session.run() that inserts tensor watch options.
406
407    Args:
408      fetches: Same as the `fetches` arg to regular `Session.run()`.
409      feed_dict: Same as the `feed_dict` arg to regular `Session.run()`.
410      options: Same as the `options` arg to regular `Session.run()`.
411      run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.
412      callable_runner: A `callable` returned by `Session.make_callable()`.
413        If not `None`, `fetches` and `feed_dict` must both be `None`.
414        Mutually exclusive with `callable_options`.
415      callable_runner_args: An optional list of arguments to `callable_runner`
416        or for `callable_options`.
417      callable_options: An instance of `config_pb2.CallableOptions`, to be
418        used with `Session._make_callable_from_options()`. Mutually exclusive
419        with `callable_runner`.
420
421    Returns:
422      Simply forwards the output of the wrapped `Session.run()` call.
423
424    Raises:
425      ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner`
426        is not `None` and either or both of `fetches` and `feed_dict` is `None`.
427    """
428    if callable_runner and callable_options:
429      raise ValueError(
430          "callable_runner and callable_options are mutually exclusive, but "
431          "are both specified in this call to BaseDebugWrapperSession.run().")
432
433    if callable_runner and (fetches or feed_dict):
434      raise ValueError(
435          "callable_runner and fetches/feed_dict are mutually exclusive, "
436          "but are used simultaneously.")
437    elif callable_options and (fetches or feed_dict):
438      raise ValueError(
439          "callable_options and fetches/feed_dict are mutually exclusive, "
440          "but are used simultaneously.")
441
442    self.increment_run_call_count()
443
444    def is_empty(x):
445      """Check whether a possibly nested structure is empty."""
446      if not nest.is_nested(x):
447        return False
448      if isinstance(x, collections_abc.Mapping):
449        return is_empty(list(x.values()))
450      for item in x:
451        if not is_empty(item):
452          return False
453      return True
454
455    empty_fetches = is_empty(fetches)
456    if empty_fetches:
457      tf_logging.info(
458          "Due to empty fetches, tfdbg Session wrapper is letting a "
459          "Session.run pass through without any debugging actions.")
460    if self._is_disabled_thread() or empty_fetches:
461      if callable_runner:
462        return callable_runner(*callable_runner_args)
463      elif callable_options:
464        # pylint:disable=protected-access
465        return self._sess._make_callable_from_options(
466            callable_options)(*callable_runner_args)
467        # pylint:enable=protected-access
468      else:
469        return self._sess.run(fetches,
470                              feed_dict=feed_dict,
471                              options=options,
472                              run_metadata=run_metadata)
473
474    # Invoke on-run-start callback and obtain response.
475    run_start_resp = self.on_run_start(
476        OnRunStartRequest(fetches, feed_dict, options, run_metadata,
477                          self._run_call_count,
478                          is_callable_runner=bool(callable_runner)))
479    _check_type(run_start_resp, OnRunStartResponse)
480
481    if run_start_resp.action == OnRunStartAction.DEBUG_RUN:
482      retvals, run_end_req = self._run_with_debugging(
483          run_start_resp, fetches, feed_dict, options, run_metadata,
484          callable_runner, callable_runner_args, callable_options)
485    elif run_start_resp.action == OnRunStartAction.PROFILE_RUN:
486      retvals, run_end_req = self._run_with_profiling(
487          run_start_resp, fetches, feed_dict, options, run_metadata,
488          callable_runner, callable_runner_args, callable_options)
489    elif run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN:
490      # Invoke run() method of the wrapped session.
491      if callable_runner:
492        retvals = callable_runner(*callable_runner_args)
493      elif callable_options:
494        # pylint:disable=protected-access
495        callable_object = self._sess._make_callable_from_options(
496            callable_options)
497        # pylint:enable=protected-access
498        retvals = callable_object(*callable_runner_args)
499      else:
500        retvals = self._sess.run(
501            fetches,
502            feed_dict=feed_dict,
503            options=options,
504            run_metadata=run_metadata)
505
506      # Prepare arg for the on-run-end callback.
507      run_end_req = OnRunEndRequest(run_start_resp.action)
508    else:
509      raise ValueError(
510          "Invalid OnRunStartAction value: %s" % run_start_resp.action)
511
512    # Invoke on-run-end callback and obtain response.
513    run_end_resp = self.on_run_end(run_end_req)
514    _check_type(run_end_resp, OnRunEndResponse)
515    # Currently run_end_resp is only a placeholder. No action is taken on it.
516
517    return retvals
518
519  def _run_with_debugging(self,
520                          run_start_resp,
521                          fetches,
522                          feed_dict,
523                          options,
524                          run_metadata,
525                          callable_runner,
526                          callable_runner_args,
527                          callable_options):
528    """Perform a session.run() or callable with debugging."""
529    # Decorate RunOption to fill in debugger tensor watch specifications.
530    decorated_run_options = None
531    if callable_options:
532      callable_options_id = id(callable_options)
533      if callable_options_id not in self._cached_callables_from_options:
534        # Make a copy of callable_options to avoid mutating it.
535        new_callable_options = config_pb2.CallableOptions()
536        new_callable_options.CopyFrom(callable_options)
537        decorated_run_options = new_callable_options.run_options
538    else:
539      decorated_run_options = options or config_pb2.RunOptions()
540
541    run_metadata = run_metadata or config_pb2.RunMetadata()
542
543    if decorated_run_options:
544      self._decorate_run_options_for_debug(
545          decorated_run_options,
546          run_start_resp.debug_urls,
547          debug_ops=run_start_resp.debug_ops,
548          node_name_regex_allowlist=(run_start_resp.node_name_regex_allowlist),
549          op_type_regex_allowlist=run_start_resp.op_type_regex_allowlist,
550          tensor_dtype_regex_allowlist=(
551              run_start_resp.tensor_dtype_regex_allowlist),
552          tolerate_debug_op_creation_failures=(
553              run_start_resp.tolerate_debug_op_creation_failures))
554
555    # Invoke the run() method of the wrapped Session. Catch any TensorFlow
556    # runtime errors.
557    tf_error = None
558    try:
559      if callable_runner:
560        retvals = callable_runner(*callable_runner_args,
561                                  options=decorated_run_options,
562                                  run_metadata=run_metadata)
563      elif callable_options:
564        # pylint:disable=protected-access
565        if callable_options_id in self._cached_callables_from_options:
566          callable_object = self._cached_callables_from_options[
567              callable_options_id]
568        else:
569          callable_object = self._sess._make_callable_from_options(
570              new_callable_options)
571          self._cached_callables_from_options[
572              callable_options_id] = callable_object
573        # pylint:enable=protected-access
574        retvals = callable_object(
575            *callable_runner_args, run_metadata=run_metadata)
576      else:
577        retvals = self._sess.run(fetches,
578                                 feed_dict=feed_dict,
579                                 options=decorated_run_options,
580                                 run_metadata=run_metadata)
581    except errors.OpError as op_error:
582      if self._pass_through_operrors:
583        raise op_error
584      tf_error = op_error
585      retvals = op_error
586
587    return retvals, OnRunEndRequest(
588        run_start_resp.action,
589        run_metadata=run_metadata,
590        client_graph_def=self._sess.graph.as_graph_def(),
591        tf_error=tf_error)
592
593  def _run_with_profiling(self,
594                          run_start_resp,
595                          fetches,
596                          feed_dict,
597                          options,
598                          run_metadata,
599                          callable_runner,
600                          callable_runner_args,
601                          callable_options):
602    """Perform a session.run() or callable with profiling."""
603    # Decorate RunOption to fill in debugger tensor watch specifications.
604    decorated_run_options = None
605    if callable_options:
606      callable_options_id = id(callable_options)
607      if callable_options_id not in self._cached_callables_from_options:
608        # Make a copy of callable_options to avoid mutating it.
609        new_callable_options = config_pb2.CallableOptions()
610        new_callable_options.CopyFrom(callable_options)
611        decorated_run_options = new_callable_options.run_options
612    else:
613      decorated_run_options = options or config_pb2.RunOptions()
614    self._decorate_run_options_for_profile(decorated_run_options)
615
616    run_metadata = run_metadata or config_pb2.RunMetadata()
617    if callable_runner:
618      retvals = callable_runner(*callable_runner_args,
619                                options=decorated_run_options,
620                                run_metadata=run_metadata)
621    elif callable_options:
622      # pylint:disable=protected-access
623      callable_object = self._sess._make_callable_from_options(
624          new_callable_options)
625      # pylint:enable=protected-access
626      retvals = callable_object(
627          *callable_runner_args, run_metadata=run_metadata)
628    else:
629      retvals = self._sess.run(fetches,
630                               feed_dict=feed_dict,
631                               options=decorated_run_options,
632                               run_metadata=run_metadata)
633    return retvals, OnRunEndRequest(
634        run_start_resp.action,
635        run_metadata=run_metadata,
636        client_graph_def=self._sess.graph.as_graph_def())
637
638  def _is_disabled_thread(self):
639    thread_name = threading.current_thread().name or ""
640    return (self._thread_name_filter_pattern and
641            not self._thread_name_filter_pattern.match(thread_name))
642
643  def run_step_fn(self, step_fn):
644    return step_fn(
645        monitored_session.MonitoredSession.StepContext(self._sess, self.run))
646
647  def partial_run_setup(self, fetches, feeds=None):
648    """Sets up the feeds and fetches for partial runs in the session."""
649    raise NotImplementedError(
650        "partial_run_setup is not implemented for debug-wrapper sessions.")
651
652  def partial_run(self, handle, fetches, feed_dict=None):
653    raise NotImplementedError(
654        "partial_run is not implemented for debug-wrapper sessions.")
655
656  def list_devices(self, *args, **kwargs):
657    return self._sess.list_devices(*args, **kwargs)
658
659  def reset(self, *args, **kwargs):
660    return self._sess.reset(*args, **kwargs)
661
662  def make_callable(self,
663                    fetches,
664                    feed_list=None,
665                    accept_options=False):
666    runner = self._sess.make_callable(
667        fetches, feed_list=feed_list, accept_options=True)
668    def wrapped_runner(*runner_args, **kwargs):
669      return self.run(None,
670                      feed_dict=None,
671                      options=kwargs.get("options", None),
672                      run_metadata=kwargs.get("run_metadata", None),
673                      callable_runner=runner,
674                      callable_runner_args=runner_args)
675    return wrapped_runner
676
677  def _make_callable_from_options(self, callable_options):
678    def wrapped_runner(*feed_values, **kwargs):
679      return self.run(None,
680                      run_metadata=kwargs.get("run_metadata", None),
681                      callable_options=callable_options,
682                      callable_runner_args=feed_values)
683    return wrapped_runner
684
685  @property
686  def run_call_count(self):
687    return self._run_call_count
688
689  def increment_run_call_count(self):
690    self._run_call_count += 1
691
692  def _is_disk_usage_reset_each_run(self):
693    """Indicates whether disk usage is reset after each Session.run.
694
695    Subclasses that clean up the disk usage after every run should
696    override this protected method.
697
698    Returns:
699      (`bool`) Whether the disk usage amount is reset to zero after
700        each Session.run.
701    """
702    return False
703
704  def _decorate_run_options_for_debug(
705      self,
706      run_options,
707      debug_urls,
708      debug_ops="DebugIdentity",
709      node_name_regex_allowlist=None,
710      op_type_regex_allowlist=None,
711      tensor_dtype_regex_allowlist=None,
712      tolerate_debug_op_creation_failures=False):
713    """Modify a RunOptions object for debug tensor watching.
714
715    Specifies request for outputting partition graphs. Adds
716    debug_tensor_watch_opts with proper debug URLs.
717
718    Args:
719      run_options: (RunOptions) the modified RunOptions object.
720      debug_urls: (list of str) debug URLs to be entered in run_options.
721        debug_tensor_watch_opts.
722      debug_ops: (str or list of str) debug op(s) to be used by the debugger.
723      node_name_regex_allowlist: Regular-expression allowlist for node
724        name.
725      op_type_regex_allowlist: Regular-expression allowlist for op type.
726      tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor
727        dtype.
728      tolerate_debug_op_creation_failures: Whether debug op creation failures
729        are to be tolerated.
730    """
731
732    run_options.output_partition_graphs = True
733    debug_utils.watch_graph(
734        run_options,
735        self._sess.graph,
736        debug_urls=debug_urls,
737        debug_ops=debug_ops,
738        node_name_regex_allowlist=node_name_regex_allowlist,
739        op_type_regex_allowlist=op_type_regex_allowlist,
740        tensor_dtype_regex_allowlist=tensor_dtype_regex_allowlist,
741        tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
742        reset_disk_byte_usage=(self._run_call_count == 1 or
743                               self._is_disk_usage_reset_each_run()))
744
745  def _decorate_run_options_for_profile(self, run_options):
746    """Modify a RunOptions object for profiling TensorFlow graph execution.
747
748    Args:
749      run_options: (RunOptions) the modified RunOptions object.
750    """
751
752    run_options.trace_level = config_pb2.RunOptions.FULL_TRACE
753
754  @abc.abstractmethod
755  def on_session_init(self, request):
756    """Callback invoked during construction of the debug-wrapper session.
757
758    This is a blocking callback.
759    The invocation happens right before the constructor ends.
760
761    Args:
762      request: (`OnSessionInitRequest`) callback request carrying information
763        such as the session being wrapped.
764
765    Returns:
766      An instance of `OnSessionInitResponse`.
767    """
768
769  @abc.abstractmethod
770  def on_run_start(self, request):
771    """Callback invoked on run() calls to the debug-wrapper session.
772
773    This is a blocking callback.
774    The invocation happens after the wrapper's run() call is entered,
775    after an increment of run call counter.
776
777    Args:
778      request: (`OnRunStartRequest`) callback request object carrying
779        information about the run call such as the fetches, feed dict, run
780        options, run metadata, and how many `run()` calls to this wrapper
781        session have occurred.
782
783    Returns:
784      An instance of `OnRunStartResponse`, carrying information to
785        debug URLs used to watch the tensors.
786    """
787
788  @abc.abstractmethod
789  def on_run_end(self, request):
790    """Callback invoked on run() calls to the debug-wrapper session.
791
792    This is a blocking callback.
793    The invocation happens right before the wrapper exits its run() call.
794
795    Args:
796      request: (`OnRunEndRequest`) callback request object carrying information
797        such as the actual action performed by the session wrapper for the
798        run() call.
799
800    Returns:
801      An instance of `OnRunStartResponse`.
802    """
803
804  def as_default(self):
805    return ops.default_session(self)
806
807  def __enter__(self):
808    if self._default_session_context_manager is None:
809      self._default_session_context_manager = self.as_default()
810    return self._default_session_context_manager.__enter__()
811
812  def __exit__(self, exec_type, exec_value, exec_tb):
813    self._default_session_context_manager.__exit__(
814        exec_type, exec_value, exec_tb)
815
816  def __del__(self):
817    if hasattr(self._sess, "__del__"):
818      self._sess.__del__()
819
820  def close(self):
821    self._sess.close()
822
823  # TODO(cais): Add _node_name_regex_allowlist and
824  #   _node_op_type_regex_allowlist.
825
826  def should_stop(self):
827    if hasattr(self._sess, "should_stop"):
828      return self._sess.should_stop()
829    else:
830      raise ValueError(
831          "The wrapped session %r does not have a method called 'should_stop'. "
832          "Do you intend to wrap a tf.MonitoredSession instead?" % self._sess)
833
834
835class WatchOptions(object):
836  """Type for return values of watch_fn."""
837
838  def __init__(self,
839               debug_ops=None,
840               node_name_regex_allowlist=None,
841               op_type_regex_allowlist=None,
842               tensor_dtype_regex_allowlist=None,
843               tolerate_debug_op_creation_failures=False):
844    """Constructor of WatchOptions: Debug watch options.
845
846    Used as return values of `watch_fn`s.
847
848    Args:
849      debug_ops: (`str` or `list of str`) Debug ops to be used.
850      node_name_regex_allowlist: Regular-expression allowlist for node_name,
851        e.g., `"(weight_[0-9]+|bias_.*)"`
852      op_type_regex_allowlist: Regular-expression allowlist for the op type of
853        nodes, e.g., `"(Variable|Add)"`.
854        If both `node_name_regex_allowlist` and `op_type_regex_allowlist`
855        are set, the two filtering operations will occur in a logical `AND`
856        relation. In other words, a node will be included if and only if it
857        hits both allowlists.
858      tensor_dtype_regex_allowlist: Regular-expression allowlist for Tensor
859        data type, e.g., `"^int.*"`.
860        This allowlist operates in logical `AND` relations to the two allowlists
861        above.
862      tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
863        failures (e.g., due to dtype incompatibility) are to be tolerated by not
864        throwing exceptions.
865    """
866    if debug_ops:
867      self.debug_ops = debug_ops
868    else:
869      self.debug_ops = ["DebugIdentity"]
870    self.node_name_regex_allowlist = node_name_regex_allowlist
871    self.op_type_regex_allowlist = op_type_regex_allowlist
872    self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist
873    self.tolerate_debug_op_creation_failures = (
874        tolerate_debug_op_creation_failures)
875
876  def __repr__(self):
877    return ("WatchOptions(debug_ops=%r, node_name_regex_allowlist=%r, "
878            "op_type_regex_allowlist=%r, tensor_dtype_regex_allowlist=%r, "
879            "tolerate_debug_op_creation_failures=%r)" %
880            (self.debug_ops, self.node_name_regex_allowlist,
881             self.op_type_regex_allowlist, self.tensor_dtype_regex_allowlist,
882             self.tolerate_debug_op_creation_failures))
883
884
885class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
886  """Base class for non-interactive (i.e., non-CLI) debug wrapper sessions."""
887
888  def __init__(self, sess, watch_fn=None, thread_name_filter=None,
889               pass_through_operrors=False):
890    """Constructor of NonInteractiveDebugWrapperSession.
891
892    Args:
893      sess: The TensorFlow `Session` object being wrapped.
894      watch_fn: (`Callable`) A Callable that maps the fetches and feeds of a
895        debugged `Session.run()` call to `WatchOptions.`
896        * Args:
897          * `fetches`: the fetches to the `Session.run()` call.
898          * `feeds`: the feeds to the `Session.run()` call.
899
900        * Returns:
901         (`tf_debug.WatchOptions`) An object containing debug options including
902           the debug ops to use, the node names, op types and/or tensor data
903           types to watch, etc. See the documentation of `tf_debug.WatchOptions`
904           for more details.
905      thread_name_filter: Regular-expression white list for threads on which the
906        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
907        more details.
908      pass_through_operrors: If true, all captured OpErrors will be
909        propagated.  By default this captures all OpErrors.
910    Raises:
911       TypeError: If a non-None `watch_fn` is specified and it is not callable.
912    """
913
914    BaseDebugWrapperSession.__init__(
915        self, sess, thread_name_filter=thread_name_filter,
916        pass_through_operrors=pass_through_operrors)
917
918    self._watch_fn = None
919    if watch_fn is not None:
920      if not callable(watch_fn):
921        raise TypeError("watch_fn is not callable")
922      self._watch_fn = watch_fn
923
924  def on_session_init(self, request):
925    """See doc of BaseDebugWrapperSession.on_run_start."""
926
927    return OnSessionInitResponse(OnSessionInitAction.PROCEED)
928
929  @abc.abstractmethod
930  def prepare_run_debug_urls(self, fetches, feed_dict):
931    """Abstract method to be implemented by concrete subclasses.
932
933    This method prepares the run-specific debug URL(s).
934
935    Args:
936      fetches: Same as the `fetches` argument to `Session.run()`
937      feed_dict: Same as the `feed_dict` argument to `Session.run()`
938
939    Returns:
940      debug_urls: (`str` or `list` of `str`) Debug URLs to be used in
941        this `Session.run()` call.
942    """
943
944  def on_run_start(self, request):
945    """See doc of BaseDebugWrapperSession.on_run_start."""
946
947    debug_urls, watch_opts = self._prepare_run_watch_config(
948        request.fetches, request.feed_dict)
949
950    return OnRunStartResponse(
951        OnRunStartAction.DEBUG_RUN,
952        debug_urls,
953        debug_ops=watch_opts.debug_ops,
954        node_name_regex_allowlist=watch_opts.node_name_regex_allowlist,
955        op_type_regex_allowlist=watch_opts.op_type_regex_allowlist,
956        tensor_dtype_regex_allowlist=watch_opts.tensor_dtype_regex_allowlist,
957        tolerate_debug_op_creation_failures=(
958            watch_opts.tolerate_debug_op_creation_failures))
959
960  def _prepare_run_watch_config(self, fetches, feed_dict):
961    """Get the debug_urls, and node/op allowlists for the current run() call.
962
963    Args:
964      fetches: Same as the `fetches` argument to `Session.run()`.
965      feed_dict: Same as the `feed_dict argument` to `Session.run()`.
966
967    Returns:
968      debug_urls: (str or list of str) Debug URLs for the current run() call.
969        Currently, the list consists of only one URL that is a file:// URL.
970      watch_options: (WatchOptions) The return value of a watch_fn, containing
971        options including debug_ops, and allowlists.
972    """
973
974    debug_urls = self.prepare_run_debug_urls(fetches, feed_dict)
975    if self._watch_fn is None:
976      watch_options = WatchOptions()
977    else:
978      watch_options = self._watch_fn(fetches, feed_dict)
979      if isinstance(watch_options, tuple):
980        # For legacy return type (tuples).
981        watch_options = WatchOptions(*watch_options)
982
983    return debug_urls, watch_options
984
985  def on_run_end(self, request):
986    """See doc of BaseDebugWrapperSession.on_run_end."""
987
988    return OnRunEndResponse()
989