• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Unit Tests for classes in dumping_wrapper.py."""
16import glob
17import os
18import tempfile
19import threading
20
21from tensorflow.python.client import session
22from tensorflow.python.debug.lib import debug_data
23from tensorflow.python.debug.wrappers import dumping_wrapper
24from tensorflow.python.debug.wrappers import framework
25from tensorflow.python.debug.wrappers import hooks
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import test_util
30from tensorflow.python.lib.io import file_io
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import state_ops
33from tensorflow.python.ops import variables
34from tensorflow.python.platform import gfile
35from tensorflow.python.platform import googletest
36from tensorflow.python.training import monitored_session
37
38
@test_util.run_v1_only("b/120545219")
class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
  """Tests for DumpingDebugWrapperSession and DumpingDebugHook.

  Every test gets a fresh, empty temporary `session_root` directory and a small
  graph built around a float32 variable `v` (initial value 10.0) that is
  modified through `assign_add` ops.
  """

  def setUp(self):
    """Creates the dump root, the test graph and an initialized session."""
    self.session_root = tempfile.mkdtemp()

    self.v = variables.VariableV1(10.0, dtype=dtypes.float32, name="v")
    self.delta = constant_op.constant(1.0, dtype=dtypes.float32, name="delta")
    self.eta = constant_op.constant(-1.4, dtype=dtypes.float32, name="eta")
    self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
    self.dec_v = state_ops.assign_add(self.v, self.eta, name="dec_v")

    self.ph = array_ops.placeholder(dtypes.float32, shape=(), name="ph")
    self.inc_w_ph = state_ops.assign_add(self.v, self.ph, name="inc_w_ph")

    self.sess = session.Session()
    self.sess.run(self.v.initializer)

  def tearDown(self):
    """Closes the session, resets the graph and deletes the dump root."""
    # Release the session's resources explicitly instead of relying on
    # garbage collection to do it at some later point.
    self.sess.close()
    ops.reset_default_graph()
    if os.path.isdir(self.session_root):
      file_io.delete_recursively(self.session_root)

  def _get_dump_dirs(self):
    """Returns the run_* dump directories under session_root, in run order.

    The paths are sorted by the integer that follows the "run_" prefix of the
    directory name, so that the i-th entry corresponds to the i-th run, which
    is what the multi-run tests rely on.

    Returns:
      A list of directory paths; empty if nothing has been dumped.
    """
    dump_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
    return sorted(
        dump_dirs, key=lambda x: int(os.path.basename(x).split("_")[1]))

  def _assert_correct_run_subdir_naming(self, run_subdir):
    """Asserts that `run_subdir` looks like "run_<positive int>_<suffix>".

    Args:
      run_subdir: Base name (not the full path) of a dump subdirectory.
    """
    self.assertStartsWith(run_subdir, "run_")
    self.assertEqual(2, run_subdir.count("_"))
    self.assertGreater(int(run_subdir.split("_")[1]), 0)

  def testConstructWrapperWithExistingNonEmptyRootDirRaisesException(self):
    """A non-empty directory is rejected as session_root."""
    dir_path = os.path.join(self.session_root, "foo")
    os.mkdir(dir_path)
    self.assertTrue(os.path.isdir(dir_path))

    # Reuse self.sess (closed in tearDown) rather than leaking a new Session.
    with self.assertRaisesRegex(
        ValueError, "session_root path points to a non-empty directory"):
      dumping_wrapper.DumpingDebugWrapperSession(
          self.sess, session_root=self.session_root, log_usage=False)

  def testConstructWrapperWithExistingFileDumpRootRaisesException(self):
    """A path that points to a regular file is rejected as session_root."""
    file_path = os.path.join(self.session_root, "foo")
    with open(file_path, "a"):  # Create the file.
      pass
    self.assertTrue(gfile.Exists(file_path))
    self.assertFalse(gfile.IsDirectory(file_path))
    with self.assertRaisesRegex(ValueError,
                                "session_root path points to a file"):
      dumping_wrapper.DumpingDebugWrapperSession(
          self.sess, session_root=file_path, log_usage=False)

  def testConstructWrapperWithNonexistentSessionRootCreatesDirectory(self):
    """A session_root that does not exist yet is created by the wrapper."""
    # Place the new directory under self.session_root so that tearDown removes
    # it (and its parent) even when the assertion below fails.
    new_dir_path = os.path.join(self.session_root, "new_dir")
    self.assertFalse(gfile.Exists(new_dir_path))
    dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=new_dir_path, log_usage=False)
    self.assertTrue(gfile.IsDirectory(new_dir_path))

  def testDumpingOnASingleRunWorks(self):
    """One run() produces one dump dir holding v and the fetch/feed info."""
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False)
    sess.run(self.inc_v)

    dump_dirs = self._get_dump_dirs()
    self.assertEqual(1, len(dump_dirs))

    self._assert_correct_run_subdir_naming(os.path.basename(dump_dirs[0]))
    dump = debug_data.DebugDumpDir(dump_dirs[0])
    self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))

    self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
    self.assertEqual(repr(None), dump.run_feed_keys_info)

  def testDumpingOnASingleRunWorksWithRelativePathForDebugDumpDir(self):
    """DebugDumpDir can load a dump through a path relative to the cwd."""
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False)
    sess.run(self.inc_v)
    dump_dirs = self._get_dump_dirs()
    # Fail with a clear assertion instead of an IndexError if nothing (or too
    # much) was dumped.
    self.assertEqual(1, len(dump_dirs))
    cwd = os.getcwd()
    try:
      os.chdir(self.session_root)
      dump = debug_data.DebugDumpDir(
          os.path.relpath(dump_dirs[0], self.session_root))
      self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))
    finally:
      # Always restore the working directory for the tests that follow.
      os.chdir(cwd)

  def testDumpingOnASingleRunWithFeedDictWorks(self):
    """A run() with a feed_dict records the feed keys in the dump."""
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False)
    feed_dict = {self.ph: 3.2}
    sess.run(self.inc_w_ph, feed_dict=feed_dict)

    dump_dirs = self._get_dump_dirs()
    self.assertEqual(1, len(dump_dirs))

    self._assert_correct_run_subdir_naming(os.path.basename(dump_dirs[0]))
    dump = debug_data.DebugDumpDir(dump_dirs[0])
    self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))

    self.assertEqual(repr(self.inc_w_ph), dump.run_fetches_info)
    self.assertEqual(repr(feed_dict.keys()), dump.run_feed_keys_info)

  def testDumpingOnMultipleRunsWorks(self):
    """Three consecutive runs produce three dumps, one per run."""
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False)
    for _ in range(3):
      sess.run(self.inc_v)

    dump_dirs = self._get_dump_dirs()
    self.assertEqual(3, len(dump_dirs))
    for i, dump_dir in enumerate(dump_dirs):
      self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
      dump = debug_data.DebugDumpDir(dump_dir)
      # The i-th dump is expected to hold v after i increments of 1.0.
      self.assertAllClose([10.0 + 1.0 * i],
                          dump.get_tensors("v", 0, "DebugIdentity"))
      self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
      self.assertEqual(repr(None), dump.run_feed_keys_info)

  def testUsingNonCallableAsWatchFnRaisesTypeError(self):
    """Passing a non-callable watch_fn raises TypeError at construction."""
    bad_watch_fn = "bad_watch_fn"
    with self.assertRaisesRegex(TypeError, "watch_fn is not callable"):
      dumping_wrapper.DumpingDebugWrapperSession(
          self.sess,
          session_root=self.session_root,
          watch_fn=bad_watch_fn,
          log_usage=False)

  def testDumpingWithLegacyWatchFnOnFetchesWorks(self):
    """Use a watch_fn that returns different allowlists for different runs."""

    def watch_fn(fetches, feeds):
      del feeds
      # A watch_fn that picks fetch name.
      if fetches.name == "inc_v:0":
        # If inc_v, watch everything.
        return "DebugIdentity", r".*", r".*"
      else:
        # If dec_v, watch nothing.
        return "DebugIdentity", r"$^", r"$^"

    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess,
        session_root=self.session_root,
        watch_fn=watch_fn,
        log_usage=False)

    for _ in range(3):
      sess.run(self.inc_v)
      sess.run(self.dec_v)

    dump_dirs = self._get_dump_dirs()
    self.assertEqual(6, len(dump_dirs))

    for i, dump_dir in enumerate(dump_dirs):
      self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
      dump = debug_data.DebugDumpDir(dump_dir)
      if i % 2 == 0:
        # Even-indexed runs fetch inc_v. Each preceding (inc_v, dec_v) pair
        # changed v by 1.0 - 1.4 = -0.4, and i // 2 such pairs have run.
        self.assertGreater(dump.size, 0)
        self.assertAllClose([10.0 - 0.4 * (i // 2)],
                            dump.get_tensors("v", 0, "DebugIdentity"))
        self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
        self.assertEqual(repr(None), dump.run_feed_keys_info)
      else:
        # Odd-indexed runs fetch dec_v, for which nothing is watched.
        self.assertEqual(0, dump.size)
        self.assertEqual(repr(self.dec_v), dump.run_fetches_info)
        self.assertEqual(repr(None), dump.run_feed_keys_info)

  def testDumpingWithLegacyWatchFnWithNonDefaultDebugOpsWorks(self):
    """Use a legacy (tuple-returning) watch_fn with non-default debug ops."""

    def watch_fn(fetches, feeds):
      del fetches, feeds
      return ["DebugIdentity", "DebugNumericSummary"], r".*", r".*"

    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess,
        session_root=self.session_root,
        watch_fn=watch_fn,
        log_usage=False)

    sess.run(self.inc_v)

    dump_dirs = self._get_dump_dirs()
    self.assertEqual(1, len(dump_dirs))
    dump = debug_data.DebugDumpDir(dump_dirs[0])

    self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))
    self.assertEqual(14,
                     len(dump.get_tensors("v", 0, "DebugNumericSummary")[0]))

  def testDumpingWithWatchFnWithNonDefaultDebugOpsWorks(self):
    """Use a WatchOptions-returning watch_fn with non-default debug ops."""

    def watch_fn(fetches, feeds):
      del fetches, feeds
      return framework.WatchOptions(
          debug_ops=["DebugIdentity", "DebugNumericSummary"],
          node_name_regex_allowlist=r"^v.*",
          op_type_regex_allowlist=r".*",
          tensor_dtype_regex_allowlist=".*_ref")

    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess,
        session_root=self.session_root,
        watch_fn=watch_fn,
        log_usage=False)

    sess.run(self.inc_v)

    dump_dirs = self._get_dump_dirs()
    self.assertEqual(1, len(dump_dirs))
    dump = debug_data.DebugDumpDir(dump_dirs[0])

    self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))
    self.assertEqual(14,
                     len(dump.get_tensors("v", 0, "DebugNumericSummary")[0]))

    # Nodes whose names do not match the "^v.*" allowlist must not be dumped.
    # The loop variable is deliberately not called `dump`, to avoid shadowing
    # the DebugDumpDir object being iterated.
    dumped_nodes = [datum.node_name for datum in dump.dumped_tensor_data]
    self.assertNotIn("inc_v", dumped_nodes)
    self.assertNotIn("delta", dumped_nodes)

  def testDumpingDebugHookWithoutWatchFnWorks(self):
    """DumpingDebugHook without a watch_fn dumps a hooked session's run."""
    dumping_hook = hooks.DumpingDebugHook(self.session_root, log_usage=False)
    mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
    mon_sess.run(self.inc_v)

    dump_dirs = self._get_dump_dirs()
    self.assertEqual(1, len(dump_dirs))

    self._assert_correct_run_subdir_naming(os.path.basename(dump_dirs[0]))
    dump = debug_data.DebugDumpDir(dump_dirs[0])
    self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))

    self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
    self.assertEqual(repr(None), dump.run_feed_keys_info)

  def testDumpingDebugHookWithStatefulWatchFnWorks(self):
    """A stateful WatchOptions watch_fn can alternate what the hook dumps."""
    watch_fn_state = {"run_counter": 0}

    def counting_watch_fn(fetches, feed_dict):
      del fetches, feed_dict
      watch_fn_state["run_counter"] += 1
      if watch_fn_state["run_counter"] % 2 == 1:
        # If odd-index run (1-based), watch every ref-type tensor.
        return framework.WatchOptions(
            debug_ops="DebugIdentity", tensor_dtype_regex_allowlist=".*_ref")
      else:
        # If even-index run, watch nothing.
        return framework.WatchOptions(
            debug_ops="DebugIdentity",
            node_name_regex_allowlist=r"^$",
            op_type_regex_allowlist=r"^$")

    dumping_hook = hooks.DumpingDebugHook(
        self.session_root, watch_fn=counting_watch_fn, log_usage=False)
    mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
    for _ in range(4):
      mon_sess.run(self.inc_v)

    dump_dirs = self._get_dump_dirs()
    self.assertEqual(4, len(dump_dirs))

    for i, dump_dir in enumerate(dump_dirs):
      self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
      dump = debug_data.DebugDumpDir(dump_dir)
      if i % 2 == 0:
        # 0-based even index == 1-based odd run: v is watched, delta is not.
        self.assertAllClose([10.0 + 1.0 * i],
                            dump.get_tensors("v", 0, "DebugIdentity"))
        self.assertNotIn("delta",
                         [datum.node_name for datum in dump.dumped_tensor_data])
      else:
        self.assertEqual(0, dump.size)

      self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
      self.assertEqual(repr(None), dump.run_feed_keys_info)

  def testDumpingDebugHookWithStatefulLegacyWatchFnWorks(self):
    """A stateful legacy watch_fn can alternate what the hook dumps."""
    watch_fn_state = {"run_counter": 0}

    def counting_watch_fn(fetches, feed_dict):
      del fetches, feed_dict
      watch_fn_state["run_counter"] += 1
      if watch_fn_state["run_counter"] % 2 == 1:
        # If odd-index run (1-based), watch everything.
        return "DebugIdentity", r".*", r".*"
      else:
        # If even-index run, watch nothing.
        return "DebugIdentity", r"$^", r"$^"

    dumping_hook = hooks.DumpingDebugHook(
        self.session_root, watch_fn=counting_watch_fn, log_usage=False)
    mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
    for _ in range(4):
      mon_sess.run(self.inc_v)

    dump_dirs = self._get_dump_dirs()
    self.assertEqual(4, len(dump_dirs))

    for i, dump_dir in enumerate(dump_dirs):
      self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
      dump = debug_data.DebugDumpDir(dump_dir)
      if i % 2 == 0:
        self.assertAllClose([10.0 + 1.0 * i],
                            dump.get_tensors("v", 0, "DebugIdentity"))
      else:
        self.assertEqual(0, dump.size)

      self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
      self.assertEqual(repr(None), dump.run_feed_keys_info)

  def testDumpingFromMultipleThreadsObeysThreadNameFilter(self):
    """Only runs from threads matching thread_name_filter are dumped."""
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False,
        thread_name_filter=r"MainThread$")

    self.assertAllClose(1.0, sess.run(self.delta))
    child_thread_result = []

    def child_thread_job():
      child_thread_result.append(sess.run(self.eta))

    thread = threading.Thread(name="ChildThread", target=child_thread_job)
    thread.start()
    thread.join()
    # The run in the child thread must still return the correct value ...
    self.assertAllClose([-1.4], child_thread_result)

    # ... but only the run of `delta` outside the child thread is dumped.
    dump_dirs = self._get_dump_dirs()
    self.assertEqual(1, len(dump_dirs))
    dump = debug_data.DebugDumpDir(dump_dirs[0])
    self.assertEqual(1, dump.size)
    self.assertEqual("delta", dump.dumped_tensor_data[0].node_name)

  def testDumpingWrapperWithEmptyFetchWorks(self):
    """Running an empty fetch list through the wrapper does not raise."""
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False)
    sess.run([])
381
382
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  googletest.main()
385