# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Debugger wrapper session that dumps debug data to file:// URLs."""
import os
import threading
import time

# Google-internal import(s).
from tensorflow.core.util import event_pb2
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import framework
from tensorflow.python.platform import gfile


def _write_log_message_event(file_path, message):
  """Write a serialized `Event` proto whose log message is `message`.

  Args:
    file_path: (`str`) Path of the file to create or overwrite.
    message: (`str`) Text stored in `Event.log_message.message`.
  """
  event = event_pb2.Event()
  event.log_message.message = message
  with gfile.Open(file_path, "wb") as f:
    f.write(event.SerializeToString())


class DumpingDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
  """Debug Session wrapper that dumps debug data to filesystem."""

  def __init__(self,
               sess,
               session_root,
               watch_fn=None,
               thread_name_filter=None,
               pass_through_operrors=None,
               log_usage=True):
    """Constructor of DumpingDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      session_root: (`str`) Path to the session root directory. Must be a
        directory that does not exist or an empty directory. A leading `~` is
        expanded. If the directory does not exist, it is created by this
        constructor.
        As the `run()` calls occur, subdirectories will be added to
        `session_root`. The subdirectories' names have the following pattern:
          run_<epoch_time_stamp_in_microseconds>_<zero_based_run_counter>
        E.g., run_1480734393835964_0
      watch_fn: (`Callable`) A Callable that can be used to define per-run
        debug ops and watched tensors. See the doc of
        `NonInteractiveDebugWrapperSession.__init__()` for details.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      pass_through_operrors: If true, all captured OpErrors will be
        propagated. By default this captures all OpErrors.
      log_usage: (`bool`) whether the usage of this class is to be logged.

    Raises:
      ValueError: If `session_root` is an existing and non-empty directory or
       if `session_root` is a file.
    """

    if log_usage:
      pass  # No logging for open-source.

    framework.NonInteractiveDebugWrapperSession.__init__(
        self, sess, watch_fn=watch_fn, thread_name_filter=thread_name_filter,
        pass_through_operrors=pass_through_operrors)

    session_root = os.path.expanduser(session_root)
    if gfile.Exists(session_root):
      if not gfile.IsDirectory(session_root):
        raise ValueError(
            "session_root path points to a file: %s" % session_root)
      elif gfile.ListDirectory(session_root):
        raise ValueError(
            "session_root path points to a non-empty directory: %s" %
            session_root)
    else:
      gfile.MakeDirs(session_root)
    self._session_root = session_root

    # Counter appended to each run directory name; guarded by the lock so
    # concurrent run() calls on this wrapper get distinct values.
    self._run_counter = 0
    self._run_counter_lock = threading.Lock()

  def prepare_run_debug_urls(self, fetches, feed_dict):
    """Implementation of abstract method in superclass.

    See doc of `NonInteractiveDebugWrapperSession.prepare_run_debug_urls()`
    for details. This implementation creates a run-specific subdirectory under
    self._session_root and stores information regarding run `fetches` and
    `feed_dict.keys()` in the subdirectory.

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`
      feed_dict: Same as the `feed_dict` argument to `Session.run()`

    Returns:
      debug_urls: (`list` of `str`) A single-element list holding the file://
        debug URL of the run directory to be used in this `Session.run()` call.
    """

    # The timestamp alone may collide between concurrent run() calls, so the
    # per-instance run counter is appended. The `with` statement guarantees
    # the lock is released even if building the path raises.
    with self._run_counter_lock:
      run_dir = os.path.join(self._session_root, "run_%d_%d" %
                             (int(time.time() * 1e6), self._run_counter))
      self._run_counter += 1
    gfile.MkDir(run_dir)

    _write_log_message_event(
        os.path.join(
            run_dir,
            debug_data.METADATA_FILE_PREFIX + debug_data.FETCHES_INFO_FILE_TAG),
        repr(fetches))

    # Only the keys of feed_dict are recorded, never the fed values.
    feed_keys_message = (repr(feed_dict.keys()) if feed_dict
                         else repr(feed_dict))
    _write_log_message_event(
        os.path.join(
            run_dir,
            debug_data.METADATA_FILE_PREFIX +
            debug_data.FEED_KEYS_INFO_FILE_TAG),
        feed_keys_message)

    return ["file://" + run_dir]