• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""tfdbg CLI as SessionRunHook."""
16
17from tensorflow.core.protobuf import config_pb2
18from tensorflow.python.debug.lib import debug_utils
19from tensorflow.python.debug.wrappers import dumping_wrapper
20from tensorflow.python.debug.wrappers import framework
21from tensorflow.python.debug.wrappers import grpc_wrapper
22from tensorflow.python.debug.wrappers import local_cli_wrapper
23from tensorflow.python.training import session_run_hook
24
25
class LocalCLIDebugHook(session_run_hook.SessionRunHook):
  """Session-run hook that exposes the tfdbg command-line interface.

  Usable with `tf.compat.v1.train.MonitoredSession`s and
  `tf.estimator.Estimator`s. It stands in for
  `tfdbg.LocalCLIDebugWrapperSession` when the session object itself is not
  directly available to the caller.
  """

  def __init__(self,
               ui_type="curses",
               dump_root=None,
               thread_name_filter=None,
               config_file_path=None):
    """Records the settings for the CLI wrapper session, which is built lazily.

    Args:
      ui_type: (`str`) requested user-interface type. Currently supported:
        (curses | readline).
      dump_root: (`str`) optional path to the dump root directory. Must be a
        directory that does not exist or an empty directory. If the directory
        does not exist, it will be created by the debugger core during debug
        `run()` calls and removed afterwards.
      thread_name_filter: Regular-expression allowlist for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      config_file_path: Optional override to the default configuration file
        path, which is at `${HOME}/.tfdbg_config`.
    """

    # The wrapper session needs a live session, so it is only created on the
    # first `before_run()` call; until then filters are queued here.
    self._session_wrapper = None
    self._pending_tensor_filters = {}

    self._ui_type = ui_type
    self._dump_root = dump_root
    self._thread_name_filter = thread_name_filter
    self._config_file_path = config_file_path

  def add_tensor_filter(self, filter_name, tensor_filter):
    """Registers a tensor filter, deferring it if the wrapper is not built yet.

    See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    This method may be called before the underlying
    `LocalCLIDebugWrapperSession` object exists; in that case the filter is
    stored and handed to the wrapper once it is constructed.

    Args:
      filter_name: See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()`
        for details.
      tensor_filter: See doc of
        `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    """

    if not self._session_wrapper:
      # No wrapper yet: remember the filter for `before_run()` to register.
      self._pending_tensor_filters[filter_name] = tensor_filter
      return

    self._session_wrapper.add_tensor_filter(filter_name, tensor_filter)

  def begin(self):
    """No-op; the wrapper session is created in `before_run()`."""
    pass

  def before_run(self, run_context):
    """Builds the wrapper on first use and prepares run options for this run.

    Args:
      run_context: Run context supplied by the hook framework; its `session`
        and `original_args` (fetches and feed_dict) are read here.

    Returns:
      A `session_run_hook.SessionRunArgs` with no fetches and no feed_dict,
      whose `options` are decorated for a debug or profile run when the
      wrapper's `on_run_start()` response asks for one.
    """
    if not self._session_wrapper:
      self._session_wrapper = local_cli_wrapper.LocalCLIDebugWrapperSession(
          run_context.session,
          ui_type=self._ui_type,
          dump_root=self._dump_root,
          thread_name_filter=self._thread_name_filter,
          config_file_path=self._config_file_path)

      # Hand over the filters that were queued before the wrapper existed.
      for queued_name, queued_filter in self._pending_tensor_filters.items():
        self._session_wrapper.add_tensor_filter(queued_name, queued_filter)

    wrapper = self._session_wrapper

    # Bump the run-call counter before building the start request, which
    # carries the updated count.
    wrapper.increment_run_call_count()

    # Translate the hook's run_context into the OnRunStartRequest that the
    # wrapper's on_run_start() expects.
    start_request = framework.OnRunStartRequest(
        run_context.original_args.fetches, run_context.original_args.feed_dict,
        None, None, wrapper.run_call_count)
    start_response = wrapper.on_run_start(start_request)

    # Remembered so that after_run() can report which action was performed.
    action = start_response.action
    self._performed_action = action

    run_args = session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=config_pb2.RunOptions())

    # pylint: disable=protected-access
    if action == framework.OnRunStartAction.DEBUG_RUN:
      wrapper._decorate_run_options_for_debug(
          run_args.options,
          start_response.debug_urls,
          debug_ops=start_response.debug_ops,
          node_name_regex_allowlist=(
              start_response.node_name_regex_allowlist),
          op_type_regex_allowlist=(
              start_response.op_type_regex_allowlist),
          tensor_dtype_regex_allowlist=(
              start_response.tensor_dtype_regex_allowlist),
          tolerate_debug_op_creation_failures=(
              start_response.tolerate_debug_op_creation_failures))
    elif action == framework.OnRunStartAction.PROFILE_RUN:
      wrapper._decorate_run_options_for_profile(run_args.options)
    # pylint: enable=protected-access

    return run_args

  def after_run(self, run_context, run_values):
    """Forwards the finished run to the wrapper's `on_run_end()`.

    Args:
      run_context: Unused.
      run_values: Values from the finished run; only `run_metadata` is read.
    """
    # Pair the action chosen in before_run() with this run's metadata.
    end_request = framework.OnRunEndRequest(
        self._performed_action, run_values.run_metadata)
    self._session_wrapper.on_run_end(end_request)
145
146
class DumpingDebugHook(session_run_hook.SessionRunHook):
  """A debugger hook that dumps debug data to filesystem.

  Usable with `tf.compat.v1.train.MonitoredSession`s and
  `tf.estimator.Estimator`s.
  """

  def __init__(self,
               session_root,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Creates a hook that dumps debug data through a dumping wrapper session.

    Args:
      session_root: See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__`.
      watch_fn: See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__`.
      thread_name_filter: Regular-expression allowlist for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      log_usage: (bool) Whether usage is to be logged.
    """

    # Built on the first `before_run()` call, once a session is available.
    self._session_wrapper = None

    self._session_root = session_root
    self._watch_fn = watch_fn
    self._thread_name_filter = thread_name_filter
    self._log_usage = log_usage

  def begin(self):
    """No-op; the wrapper session is created in `before_run()`."""
    pass

  def before_run(self, run_context):
    """Builds the wrapper on first use and configures graph watching.

    Args:
      run_context: Run context supplied by the hook framework; its `session`
        and `original_args` (fetches and feed_dict) are read here.

    Returns:
      A `session_run_hook.SessionRunArgs` with no fetches and no feed_dict,
      carrying `RunOptions` populated by `debug_utils.watch_graph`.
    """
    # Disk byte usage is reset only on the run that creates the wrapper.
    is_first_run = not self._session_wrapper
    if is_first_run:
      self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession(
          run_context.session,
          self._session_root,
          watch_fn=self._watch_fn,
          thread_name_filter=self._thread_name_filter,
          log_usage=self._log_usage)

    wrapper = self._session_wrapper
    wrapper.increment_run_call_count()

    # pylint: disable=protected-access
    dump_urls, watch_opts = wrapper._prepare_run_watch_config(
        run_context.original_args.fetches, run_context.original_args.feed_dict)
    # pylint: enable=protected-access

    options = config_pb2.RunOptions()
    debug_utils.watch_graph(
        options,
        run_context.session.graph,
        debug_urls=dump_urls,
        debug_ops=watch_opts.debug_ops,
        node_name_regex_allowlist=watch_opts.node_name_regex_allowlist,
        op_type_regex_allowlist=watch_opts.op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=watch_opts.tensor_dtype_regex_allowlist,
        tolerate_debug_op_creation_failures=(
            watch_opts.tolerate_debug_op_creation_failures),
        reset_disk_byte_usage=is_first_run)

    return session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=options)

  def after_run(self, run_context, run_values):
    """No-op; this hook does nothing after a run."""
    pass
217
218
class GrpcDebugHook(session_run_hook.SessionRunHook):
  """A hook that streams debugger-related events to any grpc_debug_server.

  For example, the debugger data server is a grpc_debug_server. The debugger
  data server writes debugger-related events it receives via GRPC to logdir.
  This enables debugging features in Tensorboard such as health pills.

  When the arguments of debug_utils.watch_graph changes, strongly consider
  changing arguments here too so that features are available to tflearn users.

  Can be used as a hook for `tf.compat.v1.train.MonitoredSession`s and
  `tf.estimator.Estimator`s.
  """

  def __init__(self,
               grpc_debug_server_addresses,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Constructs a GrpcDebugHook.

    Args:
      grpc_debug_server_addresses: (`list` of `str`) A list of the gRPC debug
        server addresses, in the format of <host:port>, with or without the
        "grpc://" prefix. For example: ["localhost:7000", "192.168.0.2:8000"].
        A value that is not a `list` (e.g. a single address `str`) is wrapped
        in a one-element list.
      watch_fn: A function that allows for customizing which ops to watch at
        which specific steps. See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__` for details.
        If `None`, a default-constructed `framework.WatchOptions` is used for
        every run.
      thread_name_filter: Regular-expression allowlist for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      log_usage: (bool) Whether usage is to be logged.
    """
    # Built on the first `before_run()` call, once a session is available.
    self._grpc_debug_wrapper_session = None
    self._thread_name_filter = thread_name_filter
    self._grpc_debug_server_addresses = (
        grpc_debug_server_addresses
        if isinstance(grpc_debug_server_addresses, list) else
        [grpc_debug_server_addresses])

    self._watch_fn = watch_fn
    self._log_usage = log_usage

  def before_run(self, run_context):
    """Called right before a session is run.

    Args:
      run_context: A session_run_hook.SessionRunContext. Encapsulates
        information on the run.

    Returns:
      A session_run_hook.SessionRunArgs object.
    """

    if not self._grpc_debug_wrapper_session:
      self._grpc_debug_wrapper_session = grpc_wrapper.GrpcDebugWrapperSession(
          run_context.session,
          self._grpc_debug_server_addresses,
          watch_fn=self._watch_fn,
          thread_name_filter=self._thread_name_filter,
          log_usage=self._log_usage)

    fetches = run_context.original_args.fetches
    feed_dict = run_context.original_args.feed_dict

    # `watch_fn` is optional (defaults to None in the constructor). Calling it
    # unconditionally would raise TypeError on the first run, so fall back to
    # default watch options when the caller did not supply one.
    if self._watch_fn is None:
      watch_options = framework.WatchOptions()
    else:
      watch_options = self._watch_fn(fetches, feed_dict)

    run_options = config_pb2.RunOptions()
    debug_utils.watch_graph(
        run_options,
        run_context.session.graph,
        debug_urls=self._grpc_debug_wrapper_session.prepare_run_debug_urls(
            fetches, feed_dict),
        debug_ops=watch_options.debug_ops,
        node_name_regex_allowlist=watch_options.node_name_regex_allowlist,
        op_type_regex_allowlist=watch_options.op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=watch_options.tensor_dtype_regex_allowlist,
        tolerate_debug_op_creation_failures=(
            watch_options.tolerate_debug_op_creation_failures))

    return session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=run_options)
299
300
class TensorBoardDebugHook(GrpcDebugHook):
  """A tfdbg hook that can be used with TensorBoard Debugger Plugin.

  Behaves like `GrpcDebugHook`, but supplies its own predefined `watch_fn`
  so the user does not have to define one. That `watch_fn`:
    1) uses `DebugIdentity` debug ops with the `gated_grpc` attribute set to
       `True`, to allow the interactive enabling and disabling of tensor
       breakpoints.
    2) watches all tensors in the graph.
  """

  def __init__(self,
               grpc_debug_server_addresses,
               thread_name_filter=None,
               send_traceback_and_source_code=True,
               log_usage=True):
    """Constructor of TensorBoardDebugHook.

    Args:
      grpc_debug_server_addresses: gRPC address(es) of debug server(s), as a
        `str` or a `list` of `str`s. E.g., "localhost:2333",
        "grpc://localhost:2333", ["192.168.0.7:2333", "192.168.0.8:2333"].
      thread_name_filter: Optional filter for thread names.
      send_traceback_and_source_code: Whether traceback of graph elements and
        the source code are to be sent to the debug server(s).
      log_usage: Whether the usage of this class is to be logged (if
        applicable).
    """

    def _watch_with_gated_grpc(fetches, feeds):
      # The same watch options are returned regardless of what is being run.
      del fetches, feeds  # Unused.
      return framework.WatchOptions(
          debug_ops=["DebugIdentity(gated_grpc=true)"])

    super().__init__(
        grpc_debug_server_addresses,
        watch_fn=_watch_with_gated_grpc,
        thread_name_filter=thread_name_filter,
        log_usage=log_usage)

    # Overwrites the list-normalized value stored by the parent constructor
    # with the caller-supplied value (a `str` or a `list`) as given.
    self._grpc_debug_server_addresses = grpc_debug_server_addresses

    # -1 is the initial "previously sent graph version" passed to
    # publish_traceback() on the first run.
    self._sent_graph_version = -1
    self._send_traceback_and_source_code = send_traceback_and_source_code
    grpc_wrapper.register_signal_handler()

  def before_run(self, run_context):
    """Optionally publishes tracebacks, then defers to `GrpcDebugHook`.

    Args:
      run_context: Run context supplied by the hook framework.

    Returns:
      Whatever `GrpcDebugHook.before_run()` returns for `run_context`.
    """
    if self._send_traceback_and_source_code:
      # Keep the returned version so the next call can pass it back in.
      graph = run_context.session.graph
      feed_dict = run_context.original_args.feed_dict
      fetches = run_context.original_args.fetches
      self._sent_graph_version = grpc_wrapper.publish_traceback(
          self._grpc_debug_server_addresses, graph, feed_dict, fetches,
          self._sent_graph_version)
    return super().before_run(run_context)
354