# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit Tests for classes in dumping_wrapper.py."""
import glob
import os
import tempfile
import threading

from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import dumping_wrapper
from tensorflow.python.debug.wrappers import framework
from tensorflow.python.debug.wrappers import hooks
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import googletest
from tensorflow.python.training import monitored_session


@test_util.run_v1_only("b/120545219")
class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
  """Tests for DumpingDebugWrapperSession and DumpingDebugHook."""

  def setUp(self):
    """Builds a small graph around variable `v` and opens a session."""
    self.session_root = tempfile.mkdtemp()

    # `v` starts at 10.0; `inc_v` adds 1.0 and `dec_v` adds -1.4 to it.
    self.v = variables.VariableV1(10.0, dtype=dtypes.float32, name="v")
    self.delta = constant_op.constant(1.0, dtype=dtypes.float32, name="delta")
    self.eta = constant_op.constant(-1.4, dtype=dtypes.float32, name="eta")
    self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
    self.dec_v = state_ops.assign_add(self.v, self.eta, name="dec_v")

    # `inc_w_ph` adds a fed scalar to `v`; used by the feed_dict test.
    self.ph = array_ops.placeholder(dtypes.float32, shape=(), name="ph")
    self.inc_w_ph = state_ops.assign_add(self.v, self.ph, name="inc_w_ph")

    self.sess = session.Session()
    self.sess.run(self.v.initializer)

  def tearDown(self):
    """Releases the session, resets the graph and removes the dump root."""
    # Close the session opened in setUp so that its resources are not leaked
    # from one test case to the next.
    self.sess.close()
    ops.reset_default_graph()
    if os.path.isdir(self.session_root):
      file_io.delete_recursively(self.session_root)

  def _assert_correct_run_subdir_naming(self, run_subdir):
    """Asserts `run_subdir` looks like `run_<positive int>_<suffix>`.

    Args:
      run_subdir: Base name (not the full path) of a per-run dump directory.
    """
    self.assertStartsWith(run_subdir, "run_")
    self.assertEqual(2, run_subdir.count("_"))
    self.assertGreater(int(run_subdir.split("_")[1]), 0)

  def _list_dump_dirs(self):
    """Returns the paths of all `run_*` entries under the session root."""
    return glob.glob(os.path.join(self.session_root, "run_*"))

  def _list_sorted_dump_dirs(self):
    """Returns the `run_*` paths ordered by the integer in their names.

    The sort key is the second "_"-separated field of each base name, parsed
    as an integer, so that the returned order follows that number rather than
    the arbitrary order produced by glob.
    """
    return sorted(
        self._list_dump_dirs(),
        key=lambda path: int(os.path.basename(path).split("_")[1]))

  def testConstructWrapperWithExistingNonEmptyRootDirRaisesException(self):
    """A session root that already contains entries is rejected."""
    dir_path = os.path.join(self.session_root, "foo")
    os.mkdir(dir_path)
    self.assertTrue(os.path.isdir(dir_path))

    with self.assertRaisesRegex(
        ValueError, "session_root path points to a non-empty directory"):
      dumping_wrapper.DumpingDebugWrapperSession(
          session.Session(), session_root=self.session_root, log_usage=False)

  def testConstructWrapperWithExistingFileDumpRootRaisesException(self):
    """A session root that is a regular file is rejected."""
    file_path = os.path.join(self.session_root, "foo")
    open(file_path, "a").close()  # Create the file
    self.assertTrue(gfile.Exists(file_path))
    self.assertFalse(gfile.IsDirectory(file_path))
    with self.assertRaisesRegex(ValueError,
                                "session_root path points to a file"):
      dumping_wrapper.DumpingDebugWrapperSession(
          session.Session(), session_root=file_path, log_usage=False)

  def testConstructWrapperWithNonexistentSessionRootCreatesDirectory(self):
    """A session root that does not exist yet is created by the wrapper."""
    # Place the not-yet-existing directory under self.session_root so that
    # tearDown removes it (and its parent) even when the assertion fails,
    # instead of leaving a stray temporary directory behind.
    new_dir_path = os.path.join(self.session_root, "new_dir")
    self.assertFalse(gfile.Exists(new_dir_path))
    dumping_wrapper.DumpingDebugWrapperSession(
        session.Session(), session_root=new_dir_path, log_usage=False)
    self.assertTrue(gfile.IsDirectory(new_dir_path))

  def testDumpingOnASingleRunWorks(self):
    """One wrapped run yields one dump directory with `v` and run info."""
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False)
    sess.run(self.inc_v)

    dump_dirs = self._list_dump_dirs()
    self.assertEqual(1, len(dump_dirs))

    self._assert_correct_run_subdir_naming(os.path.basename(dump_dirs[0]))
    dump = debug_data.DebugDumpDir(dump_dirs[0])
    self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))

    self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
    self.assertEqual(repr(None), dump.run_feed_keys_info)

  def testDumpingOnASingleRunWorksWithRelativePathForDebugDumpDir(self):
    """A dump directory can be loaded through a path relative to the cwd."""
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False)
    sess.run(self.inc_v)
    dump_dirs = self._list_dump_dirs()
    self.assertEqual(1, len(dump_dirs))

    cwd = os.getcwd()
    try:
      os.chdir(self.session_root)
      dump = debug_data.DebugDumpDir(
          os.path.relpath(dump_dirs[0], self.session_root))
      self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))
    finally:
      # Always restore the working directory so later tests are unaffected.
      os.chdir(cwd)

  def testDumpingOnASingleRunWithFeedDictWorks(self):
    """The feed_dict keys of a wrapped run are recorded with the dump."""
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False)
    feed_dict = {self.ph: 3.2}
    sess.run(self.inc_w_ph, feed_dict=feed_dict)

    dump_dirs = self._list_dump_dirs()
    self.assertEqual(1, len(dump_dirs))

    self._assert_correct_run_subdir_naming(os.path.basename(dump_dirs[0]))
    dump = debug_data.DebugDumpDir(dump_dirs[0])
    self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))

    self.assertEqual(repr(self.inc_w_ph), dump.run_fetches_info)
    self.assertEqual(repr(feed_dict.keys()), dump.run_feed_keys_info)

  def testDumpingOnMultipleRunsWorks(self):
    """Each of three wrapped runs gets its own dump directory."""
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False)
    for _ in range(3):
      sess.run(self.inc_v)

    dump_dirs = self._list_sorted_dump_dirs()
    self.assertEqual(3, len(dump_dirs))
    for i, dump_dir in enumerate(dump_dirs):
      self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
      dump = debug_data.DebugDumpDir(dump_dir)
      # The dumped value of `v` is its value before the i-th increment.
      self.assertAllClose([10.0 + 1.0 * i],
                          dump.get_tensors("v", 0, "DebugIdentity"))
      self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
      self.assertEqual(repr(None), dump.run_feed_keys_info)

  def testUsingNonCallableAsWatchFnRaisesTypeError(self):
    """Passing a non-callable watch_fn raises TypeError at construction."""
    bad_watch_fn = "bad_watch_fn"
    with self.assertRaisesRegex(TypeError, "watch_fn is not callable"):
      dumping_wrapper.DumpingDebugWrapperSession(
          self.sess,
          session_root=self.session_root,
          watch_fn=bad_watch_fn,
          log_usage=False)

  def testDumpingWithLegacyWatchFnOnFetchesWorks(self):
    """Use a watch_fn that returns different allowlists for different runs."""

    def watch_fn(fetches, feeds):
      del feeds
      # A watch_fn that picks fetch name.
      if fetches.name == "inc_v:0":
        # If inc_v, watch everything.
        return "DebugIdentity", r".*", r".*"
      else:
        # If dec_v, watch nothing.
        return "DebugIdentity", r"$^", r"$^"

    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess,
        session_root=self.session_root,
        watch_fn=watch_fn,
        log_usage=False)

    for _ in range(3):
      sess.run(self.inc_v)
      sess.run(self.dec_v)

    dump_dirs = self._list_sorted_dump_dirs()
    self.assertEqual(6, len(dump_dirs))

    for i, dump_dir in enumerate(dump_dirs):
      self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
      dump = debug_data.DebugDumpDir(dump_dir)
      if i % 2 == 0:
        # Even-index runs fetch inc_v. Every preceding (inc_v, dec_v) pair
        # changed `v` by 1.0 - 1.4 = -0.4 in total.
        self.assertGreater(dump.size, 0)
        self.assertAllClose([10.0 - 0.4 * (i // 2)],
                            dump.get_tensors("v", 0, "DebugIdentity"))
        self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
        self.assertEqual(repr(None), dump.run_feed_keys_info)
      else:
        # Odd-index runs fetch dec_v, for which nothing is watched.
        self.assertEqual(0, dump.size)
        self.assertEqual(repr(self.dec_v), dump.run_fetches_info)
        self.assertEqual(repr(None), dump.run_feed_keys_info)

  def testDumpingWithLegacyWatchFnWithNonDefaultDebugOpsWorks(self):
    """Use a legacy (tuple-returning) watch_fn with non-default debug ops."""

    def watch_fn(fetches, feeds):
      del fetches, feeds
      return ["DebugIdentity", "DebugNumericSummary"], r".*", r".*"

    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess,
        session_root=self.session_root,
        watch_fn=watch_fn,
        log_usage=False)

    sess.run(self.inc_v)

    dump_dirs = self._list_dump_dirs()
    self.assertEqual(1, len(dump_dirs))
    dump = debug_data.DebugDumpDir(dump_dirs[0])

    self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))
    self.assertEqual(14,
                     len(dump.get_tensors("v", 0, "DebugNumericSummary")[0]))

  def testDumpingWithWatchFnWithNonDefaultDebugOpsWorks(self):
    """Use a WatchOptions-returning watch_fn with non-default debug ops."""

    def watch_fn(fetches, feeds):
      del fetches, feeds
      return framework.WatchOptions(
          debug_ops=["DebugIdentity", "DebugNumericSummary"],
          node_name_regex_allowlist=r"^v.*",
          op_type_regex_allowlist=r".*",
          tensor_dtype_regex_allowlist=".*_ref")

    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess,
        session_root=self.session_root,
        watch_fn=watch_fn,
        log_usage=False)

    sess.run(self.inc_v)

    dump_dirs = self._list_dump_dirs()
    self.assertEqual(1, len(dump_dirs))
    dump = debug_data.DebugDumpDir(dump_dirs[0])

    self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))
    self.assertEqual(14,
                     len(dump.get_tensors("v", 0, "DebugNumericSummary")[0]))

    # Use a distinct loop name so the `dump` object above is not shadowed.
    dumped_nodes = [datum.node_name for datum in dump.dumped_tensor_data]
    self.assertNotIn("inc_v", dumped_nodes)
    self.assertNotIn("delta", dumped_nodes)

  def testDumpingDebugHookWithoutWatchFnWorks(self):
    """DumpingDebugHook without a watch_fn dumps a hooked run."""
    dumping_hook = hooks.DumpingDebugHook(self.session_root, log_usage=False)
    mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
    mon_sess.run(self.inc_v)

    dump_dirs = self._list_dump_dirs()
    self.assertEqual(1, len(dump_dirs))

    self._assert_correct_run_subdir_naming(os.path.basename(dump_dirs[0]))
    dump = debug_data.DebugDumpDir(dump_dirs[0])
    self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))

    self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
    self.assertEqual(repr(None), dump.run_feed_keys_info)

  def testDumpingDebugHookWithStatefulWatchFnWorks(self):
    """DumpingDebugHook honors a WatchOptions watch_fn that keeps state."""
    watch_fn_state = {"run_counter": 0}

    def counting_watch_fn(fetches, feed_dict):
      del fetches, feed_dict
      watch_fn_state["run_counter"] += 1
      if watch_fn_state["run_counter"] % 2 == 1:
        # If odd-index run (1-based), watch every ref-type tensor.
        return framework.WatchOptions(
            debug_ops="DebugIdentity", tensor_dtype_regex_allowlist=".*_ref")
      else:
        # If even-index run, watch nothing.
        return framework.WatchOptions(
            debug_ops="DebugIdentity",
            node_name_regex_allowlist=r"^$",
            op_type_regex_allowlist=r"^$")

    dumping_hook = hooks.DumpingDebugHook(
        self.session_root, watch_fn=counting_watch_fn, log_usage=False)
    mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
    for _ in range(4):
      mon_sess.run(self.inc_v)

    dump_dirs = self._list_sorted_dump_dirs()
    self.assertEqual(4, len(dump_dirs))

    for i, dump_dir in enumerate(dump_dirs):
      self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
      dump = debug_data.DebugDumpDir(dump_dir)
      if i % 2 == 0:
        # 0-based even index corresponds to a 1-based odd (watched) run.
        self.assertAllClose([10.0 + 1.0 * i],
                            dump.get_tensors("v", 0, "DebugIdentity"))
        self.assertNotIn("delta",
                         [datum.node_name for datum in dump.dumped_tensor_data])
      else:
        self.assertEqual(0, dump.size)

      self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
      self.assertEqual(repr(None), dump.run_feed_keys_info)

  def testDumpingDebugHookWithStatefulLegacyWatchFnWorks(self):
    """DumpingDebugHook honors a legacy tuple watch_fn that keeps state."""
    watch_fn_state = {"run_counter": 0}

    def counting_watch_fn(fetches, feed_dict):
      del fetches, feed_dict
      watch_fn_state["run_counter"] += 1
      if watch_fn_state["run_counter"] % 2 == 1:
        # If odd-index run (1-based), watch everything.
        return "DebugIdentity", r".*", r".*"
      else:
        # If even-index run, watch nothing.
        return "DebugIdentity", r"$^", r"$^"

    dumping_hook = hooks.DumpingDebugHook(
        self.session_root, watch_fn=counting_watch_fn, log_usage=False)
    mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
    for _ in range(4):
      mon_sess.run(self.inc_v)

    dump_dirs = self._list_sorted_dump_dirs()
    self.assertEqual(4, len(dump_dirs))

    for i, dump_dir in enumerate(dump_dirs):
      self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
      dump = debug_data.DebugDumpDir(dump_dir)
      if i % 2 == 0:
        # 0-based even index corresponds to a 1-based odd (watched) run.
        self.assertAllClose([10.0 + 1.0 * i],
                            dump.get_tensors("v", 0, "DebugIdentity"))
      else:
        self.assertEqual(0, dump.size)

      self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
      self.assertEqual(repr(None), dump.run_feed_keys_info)

  def testDumpingFromMultipleThreadsObeysThreadNameFilter(self):
    """Only runs from threads matching thread_name_filter are dumped."""
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False,
        thread_name_filter=r"MainThread$")

    self.assertAllClose(1.0, sess.run(self.delta))
    child_thread_result = []

    def child_thread_job():
      child_thread_result.append(sess.run(self.eta))

    thread = threading.Thread(name="ChildThread", target=child_thread_job)
    thread.start()
    thread.join()
    # The child-thread run still returns its result...
    self.assertAllClose([-1.4], child_thread_result)

    # ...but only the run that fetched `delta` produced a dump directory.
    dump_dirs = self._list_dump_dirs()
    self.assertEqual(1, len(dump_dirs))
    dump = debug_data.DebugDumpDir(dump_dirs[0])
    self.assertEqual(1, dump.size)
    self.assertEqual("delta", dump.dumped_tensor_data[0].node_name)

  def testDumpingWrapperWithEmptyFetchWorks(self):
    """Running the wrapper with an empty fetch list does not raise."""
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False)
    sess.run([])


if __name__ == "__main__":
  googletest.main()