# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tfdbg CLI as SessionRunHook."""

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.wrappers import dumping_wrapper
from tensorflow.python.debug.wrappers import framework
from tensorflow.python.debug.wrappers import grpc_wrapper
from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.training import session_run_hook


class LocalCLIDebugHook(session_run_hook.SessionRunHook):
  """Command-line-interface debugger hook.

  Can be used as a hook for `tf.compat.v1.train.MonitoredSession`s and
  `tf.estimator.Estimator`s. Provides a substitute for
  `tfdbg.LocalCLIDebugWrapperSession` in cases where the session is not directly
  available.
  """

  def __init__(self,
               ui_type="curses",
               dump_root=None,
               thread_name_filter=None,
               config_file_path=None):
    """Create a local debugger command-line interface (CLI) hook.

    Args:
      ui_type: (`str`) requested user-interface type. Currently supported:
        (curses | readline).
      dump_root: (`str`) optional path to the dump root directory. Must be a
        directory that does not exist or an empty directory. If the directory
        does not exist, it will be created by the debugger core during debug
        `run()` calls and removed afterwards.
      thread_name_filter: Regular-expression allowlist for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      config_file_path: Optional override to the default configuration file
        path, which is at `${HOME}/.tfdbg_config`.
    """

    self._ui_type = ui_type
    self._dump_root = dump_root
    self._thread_name_filter = thread_name_filter
    self._config_file_path = config_file_path
    # The wrapper session is created lazily in `before_run()`, because the
    # underlying session is only reachable through the run context.
    self._session_wrapper = None
    # Filters added before the wrapper session exists; flushed to the wrapper
    # when it is constructed.
    self._pending_tensor_filters = {}
    # Action chosen by the most recent `on_run_start()`; read by `after_run()`.
    self._performed_action = None

  def add_tensor_filter(self, filter_name, tensor_filter):
    """Add a tensor filter.

    See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    Override default behavior to accommodate the possibility of this method
    being called prior to the initialization of the underlying
    `LocalCLIDebugWrapperSession` object.

    Args:
      filter_name: See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()`
        for details.
      tensor_filter: See doc of
        `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    """

    if self._session_wrapper:
      self._session_wrapper.add_tensor_filter(filter_name, tensor_filter)
    else:
      # No wrapper yet: remember the filter until `before_run()` creates one.
      self._pending_tensor_filters[filter_name] = tensor_filter

  def begin(self):
    """No-op; the wrapper session is created lazily in `before_run()`."""
    pass

  def before_run(self, run_context):
    """Set up the debug wrapper and decorate the run options for this run.

    Args:
      run_context: A `session_run_hook.SessionRunContext`. Its `session` and
        `original_args` (fetches and feed_dict) are read.

    Returns:
      A `session_run_hook.SessionRunArgs` with no extra fetches or feeds, whose
      `options` are decorated for a debug or profile run when the wrapper's
      `on_run_start()` response requests one.
    """
    if not self._session_wrapper:
      self._session_wrapper = local_cli_wrapper.LocalCLIDebugWrapperSession(
          run_context.session,
          ui_type=self._ui_type,
          dump_root=self._dump_root,
          thread_name_filter=self._thread_name_filter,
          config_file_path=self._config_file_path)

      # Actually register tensor filters registered prior to the construction
      # of the underlying LocalCLIDebugWrapperSession object.
      for filter_name, tensor_filter in self._pending_tensor_filters.items():
        self._session_wrapper.add_tensor_filter(filter_name, tensor_filter)

    # Increment run call counter.
    self._session_wrapper.increment_run_call_count()

    # Adapt run_context to an instance of OnRunStartRequest for invoking
    # the wrapper session's on_run_start().
    original_args = run_context.original_args
    on_run_start_request = framework.OnRunStartRequest(
        original_args.fetches, original_args.feed_dict, None, None,
        self._session_wrapper.run_call_count)

    on_run_start_response = self._session_wrapper.on_run_start(
        on_run_start_request)
    self._performed_action = on_run_start_response.action

    run_args = session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=config_pb2.RunOptions())
    if self._performed_action == framework.OnRunStartAction.DEBUG_RUN:
      # pylint: disable=protected-access
      self._session_wrapper._decorate_run_options_for_debug(
          run_args.options,
          on_run_start_response.debug_urls,
          debug_ops=on_run_start_response.debug_ops,
          node_name_regex_allowlist=(
              on_run_start_response.node_name_regex_allowlist),
          op_type_regex_allowlist=(
              on_run_start_response.op_type_regex_allowlist),
          tensor_dtype_regex_allowlist=(
              on_run_start_response.tensor_dtype_regex_allowlist),
          tolerate_debug_op_creation_failures=(
              on_run_start_response.tolerate_debug_op_creation_failures))
      # pylint: enable=protected-access
    elif self._performed_action == framework.OnRunStartAction.PROFILE_RUN:
      # pylint: disable=protected-access
      self._session_wrapper._decorate_run_options_for_profile(run_args.options)
      # pylint: enable=protected-access

    return run_args

  def after_run(self, run_context, run_values):
    """Forward the end of the run to the wrapper session.

    Args:
      run_context: A `session_run_hook.SessionRunContext`. Unused.
      run_values: A `session_run_hook.SessionRunValues`; its `run_metadata` is
        passed on to the wrapper session.
    """
    # Adapt run_context and run_values to OnRunEndRequest and invoke the
    # wrapper session's on_run_end().
    on_run_end_request = framework.OnRunEndRequest(self._performed_action,
                                                   run_values.run_metadata)
    self._session_wrapper.on_run_end(on_run_end_request)


class DumpingDebugHook(session_run_hook.SessionRunHook):
  """A debugger hook that dumps debug data to filesystem.

  Can be used as a hook for `tf.compat.v1.train.MonitoredSession`s and
  `tf.estimator.Estimator`s.
  """

  def __init__(self,
               session_root,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Create a debugger hook that dumps debug data to the filesystem.

    Args:
      session_root: See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__`.
      watch_fn: See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__`.
      thread_name_filter: Regular-expression allowlist for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      log_usage: (bool) Whether usage is to be logged.
    """

    self._session_root = session_root
    self._watch_fn = watch_fn
    self._thread_name_filter = thread_name_filter
    self._log_usage = log_usage
    # Created lazily in `before_run()`, once a session is available.
    self._session_wrapper = None

  def begin(self):
    """No-op; the wrapper session is created lazily in `before_run()`."""
    pass

  def before_run(self, run_context):
    """Build run options that watch the graph for this run.

    Args:
      run_context: A `session_run_hook.SessionRunContext`. Its `session` and
        `original_args` (fetches and feed_dict) are read.

    Returns:
      A `session_run_hook.SessionRunArgs` with no extra fetches or feeds and
      with `options` configured by `debug_utils.watch_graph`.
    """
    # Disk byte usage is reset only on the run in which the wrapper session is
    # first created.
    reset_disk_byte_usage = False
    if not self._session_wrapper:
      self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession(
          run_context.session,
          self._session_root,
          watch_fn=self._watch_fn,
          thread_name_filter=self._thread_name_filter,
          log_usage=self._log_usage)
      reset_disk_byte_usage = True

    self._session_wrapper.increment_run_call_count()

    # pylint: disable=protected-access
    debug_urls, watch_options = self._session_wrapper._prepare_run_watch_config(
        run_context.original_args.fetches, run_context.original_args.feed_dict)
    # pylint: enable=protected-access
    run_options = config_pb2.RunOptions()
    debug_utils.watch_graph(
        run_options,
        run_context.session.graph,
        debug_urls=debug_urls,
        debug_ops=watch_options.debug_ops,
        node_name_regex_allowlist=watch_options.node_name_regex_allowlist,
        op_type_regex_allowlist=watch_options.op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=watch_options.tensor_dtype_regex_allowlist,
        tolerate_debug_op_creation_failures=(
            watch_options.tolerate_debug_op_creation_failures),
        reset_disk_byte_usage=reset_disk_byte_usage)

    return session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=run_options)

  def after_run(self, run_context, run_values):
    """No-op; this hook does nothing after a run."""
    pass


class GrpcDebugHook(session_run_hook.SessionRunHook):
  """A hook that streams debugger-related events to any grpc_debug_server.

  For example, the debugger data server is a grpc_debug_server. The debugger
  data server writes debugger-related events it receives via GRPC to logdir.
  This enables debugging features in Tensorboard such as health pills.

  When the arguments of debug_utils.watch_graph changes, strongly consider
  changing arguments here too so that features are available to tflearn users.

  Can be used as a hook for `tf.compat.v1.train.MonitoredSession`s and
  `tf.estimator.Estimator`s.
  """

  def __init__(self,
               grpc_debug_server_addresses,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Constructs a GrpcDebugHook.

    Args:
      grpc_debug_server_addresses: (`list` of `str`) A list of the gRPC debug
        server addresses, in the format of <host:port>, with or without the
        "grpc://" prefix. For example: ["localhost:7000", "192.168.0.2:8000"].
        A single address (not in a list) is also accepted and is wrapped in a
        one-element list.
      watch_fn: A function that allows for customizing which ops to watch at
        which specific steps. See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__` for details. If
        `None`, a default-constructed `framework.WatchOptions` is used on every
        run.
      thread_name_filter: Regular-expression allowlist for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      log_usage: (bool) Whether usage is to be logged.
    """
    # Created lazily in `before_run()`, once a session is available.
    self._grpc_debug_wrapper_session = None
    self._thread_name_filter = thread_name_filter
    # Always store the addresses as a list. A tuple is treated as a sequence of
    # addresses rather than being wrapped as a single element.
    if isinstance(grpc_debug_server_addresses, (list, tuple)):
      self._grpc_debug_server_addresses = list(grpc_debug_server_addresses)
    else:
      self._grpc_debug_server_addresses = [grpc_debug_server_addresses]

    self._watch_fn = watch_fn
    self._log_usage = log_usage

  def before_run(self, run_context):
    """Called right before a session is run.

    Args:
      run_context: A session_run_hook.SessionRunContext. Encapsulates
        information on the run.

    Returns:
      A session_run_hook.SessionRunArgs object.
    """

    if not self._grpc_debug_wrapper_session:
      self._grpc_debug_wrapper_session = grpc_wrapper.GrpcDebugWrapperSession(
          run_context.session,
          self._grpc_debug_server_addresses,
          watch_fn=self._watch_fn,
          thread_name_filter=self._thread_name_filter,
          log_usage=self._log_usage)

    fetches = run_context.original_args.fetches
    feed_dict = run_context.original_args.feed_dict
    # `watch_fn` defaults to None in `__init__`; calling it unconditionally
    # would raise TypeError, so fall back to default watch options.
    if self._watch_fn is None:
      watch_options = framework.WatchOptions()
    else:
      watch_options = self._watch_fn(fetches, feed_dict)
    run_options = config_pb2.RunOptions()
    debug_utils.watch_graph(
        run_options,
        run_context.session.graph,
        debug_urls=self._grpc_debug_wrapper_session.prepare_run_debug_urls(
            fetches, feed_dict),
        debug_ops=watch_options.debug_ops,
        node_name_regex_allowlist=watch_options.node_name_regex_allowlist,
        op_type_regex_allowlist=watch_options.op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=watch_options.tensor_dtype_regex_allowlist,
        tolerate_debug_op_creation_failures=(
            watch_options.tolerate_debug_op_creation_failures))

    return session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=run_options)


class TensorBoardDebugHook(GrpcDebugHook):
  """A tfdbg hook that can be used with TensorBoard Debugger Plugin.

  This hook is the same as `GrpcDebugHook`, except that it uses a predefined
  `watch_fn` that
    1) uses `DebugIdentity` debug ops with the `gated_grpc` attribute set to
      `True`, to allow the interactive enabling and disabling of tensor
      breakpoints.
    2) watches all tensors in the graph.
  This saves the need for the user to define a `watch_fn`.
  """

  def __init__(self,
               grpc_debug_server_addresses,
               thread_name_filter=None,
               send_traceback_and_source_code=True,
               log_usage=True):
    """Constructor of TensorBoardDebugHook.

    Args:
      grpc_debug_server_addresses: gRPC address(es) of debug server(s), as a
        `str` or a `list` of `str`s. E.g., "localhost:2333",
        "grpc://localhost:2333", ["192.168.0.7:2333", "192.168.0.8:2333"].
      thread_name_filter: Optional filter for thread names.
      send_traceback_and_source_code: Whether traceback of graph elements and
        the source code are to be sent to the debug server(s).
      log_usage: Whether the usage of this class is to be logged (if
        applicable).
    """

    def _gated_grpc_watch_fn(fetches, feeds):
      del fetches, feeds  # Unused.
      return framework.WatchOptions(
          debug_ops=["DebugIdentity(gated_grpc=true)"])

    # The base constructor stores the addresses normalized to a list in
    # `self._grpc_debug_server_addresses`; it is deliberately not overwritten
    # with the raw argument here.
    super(TensorBoardDebugHook, self).__init__(
        grpc_debug_server_addresses,
        watch_fn=_gated_grpc_watch_fn,
        thread_name_filter=thread_name_filter,
        log_usage=log_usage)

    self._send_traceback_and_source_code = send_traceback_and_source_code
    # Graph version returned by the last `publish_traceback()` call; -1 until
    # the first call.
    self._sent_graph_version = -1
    grpc_wrapper.register_signal_handler()

  def before_run(self, run_context):
    """Optionally publish tracebacks, then defer to `GrpcDebugHook`.

    Args:
      run_context: A `session_run_hook.SessionRunContext`. Its `session.graph`
        and `original_args` (feed_dict and fetches) are read.

    Returns:
      The `session_run_hook.SessionRunArgs` produced by
      `GrpcDebugHook.before_run`.
    """
    if self._send_traceback_and_source_code:
      self._sent_graph_version = grpc_wrapper.publish_traceback(
          self._grpc_debug_server_addresses, run_context.session.graph,
          run_context.original_args.feed_dict,
          run_context.original_args.fetches, self._sent_graph_version)
    return super(TensorBoardDebugHook, self).before_run(run_context)