• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Unit tests for local command-line-interface debug wrapper session."""
16import os
17import tempfile
18
19import numpy as np
20
21from tensorflow.core.protobuf import config_pb2
22from tensorflow.core.protobuf import rewriter_config_pb2
23from tensorflow.python.client import session
24from tensorflow.python.debug.cli import cli_config
25from tensorflow.python.debug.cli import cli_shared
26from tensorflow.python.debug.cli import debugger_cli_common
27from tensorflow.python.debug.cli import ui_factory
28from tensorflow.python.debug.wrappers import local_cli_wrapper
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import errors
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import test_util
34from tensorflow.python.lib.io import file_io
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import math_ops
38# Import resource_variable_ops for the variables-to-tensor implicit conversion.
39from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
40from tensorflow.python.ops import sparse_ops
41from tensorflow.python.ops import state_ops
42from tensorflow.python.ops import variables
43from tensorflow.python.platform import googletest
44from tensorflow.python.training import monitored_session
45from tensorflow.python.training import session_run_hook
46
47
class LocalCLIDebuggerWrapperSessionForTest(
    local_cli_wrapper.LocalCLIDebugWrapperSession):
  """Headless stand-in for LocalCLIDebugWrapperSession used by the tests.

  The interactive CLI hooks of the parent class are overridden: instead of
  waiting for user input, a scripted list of commands is replayed, and the
  values passed through the hooks are collected in `self.observers` so that
  tests can make assertions about them afterwards.
  """

  def __init__(self, command_sequence, sess, dump_root=None):
    """Creates the scripted wrapper session.

    Args:
      command_sequence: (list of list of str) Scripted commands. Each element
        is a full argument vector that starts with the command prefix, for
        example ["run", "-n"] or ["print_feed", "input:0"]. The commands are
        consumed in order across all CLI launches of this object.
      sess: Forwarded to LocalCLIDebugWrapperSession.__init__.
      dump_root: Forwarded to LocalCLIDebugWrapperSession.__init__.
    """
    super().__init__(sess, dump_root=dump_root, log_usage=False)

    self._command_sequence = command_sequence
    self._command_pointer = 0

    # One list per observed quantity; the overridden hooks append to them.
    self.observers = {
        observer_name: [] for observer_name in (
            "debug_dumps",
            "tf_errors",
            "run_start_cli_run_numbers",
            "run_end_cli_run_numbers",
            "print_feed_responses",
            "profiler_py_graphs",
            "profiler_run_metadata",
        )
    }

  def _prep_cli_for_run_start(self):
    """No-op: there is no real CLI to set up in a headless test."""

  def _prep_debug_cli_for_run_end(self,
                                  debug_dump,
                                  tf_error,
                                  passed_filter,
                                  passed_filter_exclude_op_names):
    """Records the debug dump and TF error (or None) seen at run end."""
    del passed_filter, passed_filter_exclude_op_names  # Unused here.
    recorded = self.observers
    recorded["debug_dumps"].append(debug_dump)
    recorded["tf_errors"].append(tf_error)

  def _prep_profile_cli_for_run_end(self, py_graph, run_metadata):
    """Records the Python graph and RunMetadata seen at profiled run end."""
    recorded = self.observers
    recorded["profiler_py_graphs"].append(py_graph)
    recorded["profiler_run_metadata"].append(run_metadata)

  def _launch_cli(self):
    """Replays scripted commands in place of an interactive CLI.

    The run() call number is logged under the run-start or run-end key. Then
    commands are consumed from the shared sequence until a handler raises
    CommandLineExit (in these tests, presumably the "run" handler).

    Returns:
      The exit token carried by the CommandLineExit, or None if the scripted
      command sequence is exhausted first.

    Raises:
      ValueError: If a scripted command has an unsupported prefix.
    """
    if self._is_run_start:
      run_numbers_key = "run_start_cli_run_numbers"
    else:
      run_numbers_key = "run_end_cli_run_numbers"
    self.observers[run_numbers_key].append(self._run_call_count)

    # A throwaway config file, so no user-level tfdbg config is touched.
    config_path = os.path.join(tempfile.mkdtemp(), ".tfdbg_config")
    readline_cli = ui_factory.get_ui(
        "readline", config=cli_config.CLIConfig(config_file_path=config_path))
    self._register_this_run_info(readline_cli)

    while self._command_pointer < len(self._command_sequence):
      current = self._command_sequence[self._command_pointer]
      self._command_pointer += 1

      prefix = current[0]
      try:
        if prefix == "run":
          self._run_handler(current[1:])
        elif prefix == "print_feed":
          feed_output = self._print_feed_handler(current[1:])
          self.observers["print_feed_responses"].append(feed_output)
        else:
          raise ValueError("Unrecognized command prefix: %s" % prefix)
      except debugger_cli_common.CommandLineExit as exit_signal:
        return exit_signal.exit_token
129
130
131@test_util.run_v1_only("b/120545219")
132class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
133
134  def setUp(self):
135    self._tmp_dir = tempfile.mkdtemp()
136
137    self.v = variables.VariableV1(10.0, name="v")
138    self.w = variables.VariableV1(21.0, name="w")
139    self.delta = constant_op.constant(1.0, name="delta")
140    self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
141
142    self.w_int = control_flow_ops.with_dependencies(
143        [self.inc_v],
144        math_ops.cast(self.w, dtypes.int32, name="w_int_inner"),
145        name="w_int_outer")
146
147    self.ph = array_ops.placeholder(dtypes.float32, name="ph")
148    self.xph = array_ops.transpose(self.ph, name="xph")
149    self.m = constant_op.constant(
150        [[0.0, 1.0, 2.0], [-4.0, -1.0, 0.0]], dtype=dtypes.float32, name="m")
151    self.y = math_ops.matmul(self.m, self.xph, name="y")
152
153    self.sparse_ph = array_ops.sparse_placeholder(
154        dtypes.float32, shape=([5, 5]), name="sparse_placeholder")
155    self.sparse_add = sparse_ops.sparse_add(self.sparse_ph, self.sparse_ph)
156
157    rewriter_config = rewriter_config_pb2.RewriterConfig(
158        disable_model_pruning=True,
159        arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
160        dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
161    graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
162    config_proto = config_pb2.ConfigProto(graph_options=graph_options)
163    self.sess = session.Session(config=config_proto)
164
165    # Initialize variable.
166    self.sess.run(variables.global_variables_initializer())
167
168  def tearDown(self):
169    ops.reset_default_graph()
170    if os.path.isdir(self._tmp_dir):
171      file_io.delete_recursively(self._tmp_dir)
172
173  def testConstructWrapper(self):
174    local_cli_wrapper.LocalCLIDebugWrapperSession(
175        session.Session(), log_usage=False)
176
177  def testConstructWrapperWithExistingNonEmptyDumpRoot(self):
178    dir_path = os.path.join(self._tmp_dir, "foo")
179    os.mkdir(dir_path)
180    self.assertTrue(os.path.isdir(dir_path))
181
182    with self.assertRaisesRegex(
183        ValueError, "dump_root path points to a non-empty directory"):
184      local_cli_wrapper.LocalCLIDebugWrapperSession(
185          session.Session(), dump_root=self._tmp_dir, log_usage=False)
186
187  def testConstructWrapperWithExistingFileDumpRoot(self):
188    file_path = os.path.join(self._tmp_dir, "foo")
189    open(file_path, "a").close()  # Create the file
190    self.assertTrue(os.path.isfile(file_path))
191    with self.assertRaisesRegex(ValueError, "dump_root path points to a file"):
192      local_cli_wrapper.LocalCLIDebugWrapperSession(
193          session.Session(), dump_root=file_path, log_usage=False)
194
195  def testRunsUnderDebugMode(self):
196    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
197        [["run"], ["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
198
199    # run under debug mode twice.
200    wrapped_sess.run(self.inc_v)
201    wrapped_sess.run(self.inc_v)
202
203    # Verify that the assign_add op did take effect.
204    self.assertAllClose(12.0, self.sess.run(self.v))
205
206    # Assert correct run call numbers for which the CLI has been launched at
207    # run-start and run-end.
208    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
209    self.assertEqual([1, 2], wrapped_sess.observers["run_end_cli_run_numbers"])
210
211    # Verify that the dumps have been generated and picked up during run-end.
212    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
213
214    # Verify that the TensorFlow runtime errors are picked up and in this case,
215    # they should be both None.
216    self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])
217
218  def testRunsWithEmptyStringDumpRootWorks(self):
219    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
220        [["run"], ["run"]], self.sess, dump_root="")
221
222    # run under debug mode.
223    wrapped_sess.run(self.inc_v)
224
225    self.assertAllClose(11.0, self.sess.run(self.v))
226
227  def testRunInfoOutputAtRunEndIsCorrect(self):
228    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
229        [["run"], ["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
230
231    wrapped_sess.run(self.inc_v)
232    run_info_output = wrapped_sess._run_info_handler([])
233
234    tfdbg_logo = cli_shared.get_tfdbg_logo()
235
236    # The run_info output in the first run() call should contain the tfdbg logo.
237    self.assertEqual(tfdbg_logo.lines,
238                     run_info_output.lines[:len(tfdbg_logo.lines)])
239    menu = run_info_output.annotations[debugger_cli_common.MAIN_MENU_KEY]
240    self.assertIn("list_tensors", menu.captions())
241
242    wrapped_sess.run(self.inc_v)
243    run_info_output = wrapped_sess._run_info_handler([])
244
245    # The run_info output in the second run() call should NOT contain the logo.
246    self.assertNotEqual(tfdbg_logo.lines,
247                        run_info_output.lines[:len(tfdbg_logo.lines)])
248    menu = run_info_output.annotations[debugger_cli_common.MAIN_MENU_KEY]
249    self.assertIn("list_tensors", menu.captions())
250
251  def testRunsUnderNonDebugMode(self):
252    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
253        [["run", "-n"], ["run", "-n"], ["run", "-n"]],
254        self.sess, dump_root=self._tmp_dir)
255
256    # run three times.
257    wrapped_sess.run(self.inc_v)
258    wrapped_sess.run(self.inc_v)
259    wrapped_sess.run(self.inc_v)
260
261    self.assertAllClose(13.0, self.sess.run(self.v))
262
263    self.assertEqual([1, 2, 3],
264                     wrapped_sess.observers["run_start_cli_run_numbers"])
265    self.assertEqual([], wrapped_sess.observers["run_end_cli_run_numbers"])
266
267  def testRunningWithSparsePlaceholderFeedWorks(self):
268    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
269        [["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
270
271    sparse_feed = ([[0, 1], [0, 2]], [10.0, 20.0])
272    sparse_result = wrapped_sess.run(
273        self.sparse_add, feed_dict={self.sparse_ph: sparse_feed})
274    self.assertAllEqual([[0, 1], [0, 2]], sparse_result.indices)
275    self.assertAllClose([20.0, 40.0], sparse_result.values)
276
277  def testRunsUnderNonDebugThenDebugMode(self):
278    # Do two NON_DEBUG_RUNs, followed by DEBUG_RUNs.
279    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
280        [["run", "-n"], ["run", "-n"], ["run"], ["run"]],
281        self.sess, dump_root=self._tmp_dir)
282
283    # run three times.
284    wrapped_sess.run(self.inc_v)
285    wrapped_sess.run(self.inc_v)
286    wrapped_sess.run(self.inc_v)
287
288    self.assertAllClose(13.0, self.sess.run(self.v))
289
290    self.assertEqual([1, 2, 3],
291                     wrapped_sess.observers["run_start_cli_run_numbers"])
292
293    # Here, the CLI should have been launched only under the third run,
294    # because the first and second runs are NON_DEBUG.
295    self.assertEqual([3], wrapped_sess.observers["run_end_cli_run_numbers"])
296    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
297    self.assertEqual([None], wrapped_sess.observers["tf_errors"])
298
299  def testRunMultipleTimesWithinLimit(self):
300    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
301        [["run", "-t", "3"], ["run"]],
302        self.sess, dump_root=self._tmp_dir)
303
304    # run three times.
305    wrapped_sess.run(self.inc_v)
306    wrapped_sess.run(self.inc_v)
307    wrapped_sess.run(self.inc_v)
308
309    self.assertAllClose(13.0, self.sess.run(self.v))
310
311    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
312    self.assertEqual([3], wrapped_sess.observers["run_end_cli_run_numbers"])
313    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
314    self.assertEqual([None], wrapped_sess.observers["tf_errors"])
315
316  def testRunMultipleTimesOverLimit(self):
317    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
318        [["run", "-t", "3"]], self.sess, dump_root=self._tmp_dir)
319
320    # run twice, which is less than the number of times specified by the
321    # command.
322    wrapped_sess.run(self.inc_v)
323    wrapped_sess.run(self.inc_v)
324
325    self.assertAllClose(12.0, self.sess.run(self.v))
326
327    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
328    self.assertEqual([], wrapped_sess.observers["run_end_cli_run_numbers"])
329    self.assertEqual(0, len(wrapped_sess.observers["debug_dumps"]))
330    self.assertEqual([], wrapped_sess.observers["tf_errors"])
331
332  def testRunMixingDebugModeAndMultipleTimes(self):
333    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
334        [["run", "-n"], ["run", "-t", "2"], ["run"], ["run"]],
335        self.sess, dump_root=self._tmp_dir)
336
337    # run four times.
338    wrapped_sess.run(self.inc_v)
339    wrapped_sess.run(self.inc_v)
340    wrapped_sess.run(self.inc_v)
341    wrapped_sess.run(self.inc_v)
342
343    self.assertAllClose(14.0, self.sess.run(self.v))
344
345    self.assertEqual([1, 2],
346                     wrapped_sess.observers["run_start_cli_run_numbers"])
347    self.assertEqual([3, 4], wrapped_sess.observers["run_end_cli_run_numbers"])
348    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
349    self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])
350
351  def testDebuggingMakeCallableTensorRunnerWorks(self):
352    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
353        [["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
354    v = variables.VariableV1(42)
355    tensor_runner = wrapped_sess.make_callable(v)
356    self.sess.run(v.initializer)
357
358    self.assertAllClose(42, tensor_runner())
359    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
360
361  def testDebuggingMakeCallableTensorRunnerWithCustomRunOptionsWorks(self):
362    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
363        [["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
364    a = constant_op.constant(42)
365    tensor_runner = wrapped_sess.make_callable(a)
366
367    run_options = config_pb2.RunOptions(
368        trace_level=config_pb2.RunOptions.FULL_TRACE)
369    run_metadata = config_pb2.RunMetadata()
370    self.assertAllClose(
371        42, tensor_runner(options=run_options, run_metadata=run_metadata))
372    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
373    self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
374
375  def testDebuggingMakeCallableOperationRunnerWorks(self):
376    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
377        [["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
378    v = variables.VariableV1(10.0)
379    inc_v = state_ops.assign_add(v, 1.0)
380    op_runner = wrapped_sess.make_callable(inc_v.op)
381    self.sess.run(v.initializer)
382
383    op_runner()
384    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
385    self.assertEqual(11.0, self.sess.run(v))
386
387  def testDebuggingMakeCallableRunnerWithFeedListWorks(self):
388    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
389        [["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
390    ph1 = array_ops.placeholder(dtypes.float32)
391    ph2 = array_ops.placeholder(dtypes.float32)
392    a = math_ops.add(ph1, ph2)
393    tensor_runner = wrapped_sess.make_callable(a, feed_list=[ph1, ph2])
394
395    self.assertAllClose(42.0, tensor_runner(41.0, 1.0))
396    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
397
398  def testDebuggingMakeCallableFromOptionsWithZeroFeedWorks(self):
399    variable_1 = variables.VariableV1(
400        10.5, dtype=dtypes.float32, name="variable_1")
401    a = math_ops.add(variable_1, variable_1, "callable_a")
402    math_ops.add(a, a, "callable_b")
403    self.sess.run(variable_1.initializer)
404
405    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
406        [["run"]] * 3, self.sess, dump_root=self._tmp_dir)
407    callable_options = config_pb2.CallableOptions()
408    callable_options.fetch.append("callable_b")
409    sess_callable = wrapped_sess._make_callable_from_options(callable_options)
410
411    for _ in range(2):
412      callable_output = sess_callable()
413      self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0])
414
415    debug_dumps = wrapped_sess.observers["debug_dumps"]
416    self.assertEqual(2, len(debug_dumps))
417    for debug_dump in debug_dumps:
418      node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
419      self.assertItemsEqual(
420          ["callable_a", "callable_b", "variable_1", "variable_1/read"],
421          node_names)
422
423  def testDebuggingMakeCallableFromOptionsWithOneFeedWorks(self):
424    ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1")
425    a = math_ops.add(ph1, ph1, "callable_a")
426    math_ops.add(a, a, "callable_b")
427
428    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
429        [["run"]] * 3, self.sess, dump_root=self._tmp_dir)
430    callable_options = config_pb2.CallableOptions()
431    callable_options.feed.append("callable_ph1")
432    callable_options.fetch.append("callable_b")
433    sess_callable = wrapped_sess._make_callable_from_options(callable_options)
434
435    ph1_value = np.array([10.5, -10.5], dtype=np.float32)
436
437    for _ in range(2):
438      callable_output = sess_callable(ph1_value)
439      self.assertAllClose(
440          np.array([42.0, -42.0], dtype=np.float32), callable_output[0])
441
442    debug_dumps = wrapped_sess.observers["debug_dumps"]
443    self.assertEqual(2, len(debug_dumps))
444    for debug_dump in debug_dumps:
445      node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
446      self.assertIn("callable_a", node_names)
447      self.assertIn("callable_b", node_names)
448
449  def testDebuggingMakeCallableFromOptionsWithTwoFeedsWorks(self):
450    ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1")
451    ph2 = array_ops.placeholder(dtypes.float32, name="callable_ph2")
452    a = math_ops.add(ph1, ph2, "callable_a")
453    math_ops.add(a, a, "callable_b")
454
455    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
456        [["run"]] * 3, self.sess, dump_root=self._tmp_dir)
457    callable_options = config_pb2.CallableOptions()
458    callable_options.feed.append("callable_ph1")
459    callable_options.feed.append("callable_ph2")
460    callable_options.fetch.append("callable_b")
461    sess_callable = wrapped_sess._make_callable_from_options(callable_options)
462
463    ph1_value = np.array(5.0, dtype=np.float32)
464    ph2_value = np.array(16.0, dtype=np.float32)
465
466    for _ in range(2):
467      callable_output = sess_callable(ph1_value, ph2_value)
468      self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0])
469
470    debug_dumps = wrapped_sess.observers["debug_dumps"]
471    self.assertEqual(2, len(debug_dumps))
472    for debug_dump in debug_dumps:
473      node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
474      self.assertIn("callable_a", node_names)
475      self.assertIn("callable_b", node_names)
476
477  def testDebugMakeCallableFromOptionsWithCustomOptionsAndMetadataWorks(self):
478    variable_1 = variables.VariableV1(
479        10.5, dtype=dtypes.float32, name="variable_1")
480    a = math_ops.add(variable_1, variable_1, "callable_a")
481    math_ops.add(a, a, "callable_b")
482    self.sess.run(variable_1.initializer)
483
484    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
485        [["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
486    callable_options = config_pb2.CallableOptions()
487    callable_options.fetch.append("callable_b")
488    callable_options.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE
489
490    sess_callable = wrapped_sess._make_callable_from_options(callable_options)
491
492    run_metadata = config_pb2.RunMetadata()
493    # Call the callable with a custom run_metadata.
494    callable_output = sess_callable(run_metadata=run_metadata)
495    # Verify that step_stats is populated in the custom run_metadata.
496    self.assertTrue(run_metadata.step_stats)
497    self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0])
498
499    debug_dumps = wrapped_sess.observers["debug_dumps"]
500    self.assertEqual(1, len(debug_dumps))
501    debug_dump = debug_dumps[0]
502    node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
503    self.assertItemsEqual(
504        ["callable_a", "callable_b", "variable_1", "variable_1/read"],
505        node_names)
506
507  def testRuntimeErrorShouldBeCaught(self):
508    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
509        [["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
510
511    # Do a run that should lead to an TensorFlow runtime error.
512    wrapped_sess.run(self.y, feed_dict={self.ph: [[0.0], [1.0], [2.0]]})
513
514    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
515    self.assertEqual([1], wrapped_sess.observers["run_end_cli_run_numbers"])
516    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
517
518    # Verify that the runtime error is caught by the wrapped session properly.
519    self.assertEqual(1, len(wrapped_sess.observers["tf_errors"]))
520    tf_error = wrapped_sess.observers["tf_errors"][0]
521    self.assertEqual("y", tf_error.op.name)
522
523  def testRunTillFilterPassesShouldLaunchCLIAtCorrectRun(self):
524    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
525        [["run", "-f", "v_greater_than_twelve"],
526         ["run", "-f", "v_greater_than_twelve"],
527         ["run"]],
528        self.sess,
529        dump_root=self._tmp_dir)
530
531    def v_greater_than_twelve(datum, tensor):
532      return datum.node_name == "v" and tensor > 12.0
533
534    # Verify that adding the same tensor filter more than once is tolerated
535    # (i.e., as if it were added only once).
536    wrapped_sess.add_tensor_filter("v_greater_than_twelve",
537                                   v_greater_than_twelve)
538    wrapped_sess.add_tensor_filter("v_greater_than_twelve",
539                                   v_greater_than_twelve)
540
541    # run five times.
542    wrapped_sess.run(self.inc_v)
543    wrapped_sess.run(self.inc_v)
544    wrapped_sess.run(self.inc_v)
545    wrapped_sess.run(self.inc_v)
546    wrapped_sess.run(self.inc_v)
547
548    self.assertAllClose(15.0, self.sess.run(self.v))
549
550    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
551
552    # run-end CLI should NOT have been launched for run #2 and #3, because only
553    # starting from run #4 v becomes greater than 12.0.
554    self.assertEqual([4, 5], wrapped_sess.observers["run_end_cli_run_numbers"])
555
556    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
557    self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])
558
559  def testRunTillFilterPassesWithExcludeOpNames(self):
560    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
561        [["run", "-f", "greater_than_twelve",
562          "--filter_exclude_node_names", "inc_v.*"],
563         ["run"], ["run"]],
564        self.sess,
565        dump_root=self._tmp_dir)
566
567    def greater_than_twelve(datum, tensor):
568      del datum  # Unused.
569      return tensor > 12.0
570
571    # Verify that adding the same tensor filter more than once is tolerated
572    # (i.e., as if it were added only once).
573    wrapped_sess.add_tensor_filter("greater_than_twelve", greater_than_twelve)
574
575    # run five times.
576    wrapped_sess.run(self.inc_v)
577    wrapped_sess.run(self.inc_v)
578    wrapped_sess.run(self.inc_v)
579    wrapped_sess.run(self.inc_v)
580
581    self.assertAllClose(14.0, self.sess.run(self.v))
582
583    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
584
585    # Due to the --filter_exclude_op_names flag, the run-end CLI should show up
586    # not after run 3, but after run 4.
587    self.assertEqual([4], wrapped_sess.observers["run_end_cli_run_numbers"])
588
589  def testRunTillFilterPassesWorksInConjunctionWithOtherNodeNameFilter(self):
590    """Test that --.*_filter flags work in conjunction with -f.
591
592    In other words, test that you can use a tensor filter on a subset of
593    the tensors.
594    """
595    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
596        [["run", "-f", "v_greater_than_twelve", "--node_name_filter", "v$"],
597         ["run", "-f", "v_greater_than_twelve", "--node_name_filter", "v$"],
598         ["run"]],
599        self.sess,
600        dump_root=self._tmp_dir)
601
602    def v_greater_than_twelve(datum, tensor):
603      return datum.node_name == "v" and tensor > 12.0
604    wrapped_sess.add_tensor_filter("v_greater_than_twelve",
605                                   v_greater_than_twelve)
606
607    # run five times.
608    wrapped_sess.run(self.inc_v)
609    wrapped_sess.run(self.inc_v)
610    wrapped_sess.run(self.inc_v)
611    wrapped_sess.run(self.inc_v)
612    wrapped_sess.run(self.inc_v)
613
614    self.assertAllClose(15.0, self.sess.run(self.v))
615
616    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
617
618    # run-end CLI should NOT have been launched for run #2 and #3, because only
619    # starting from run #4 v becomes greater than 12.0.
620    self.assertEqual([4, 5], wrapped_sess.observers["run_end_cli_run_numbers"])
621
622    debug_dumps = wrapped_sess.observers["debug_dumps"]
623    self.assertEqual(2, len(debug_dumps))
624    self.assertEqual(1, len(debug_dumps[0].dumped_tensor_data))
625    self.assertEqual("v:0", debug_dumps[0].dumped_tensor_data[0].tensor_name)
626    self.assertEqual(1, len(debug_dumps[1].dumped_tensor_data))
627    self.assertEqual("v:0", debug_dumps[1].dumped_tensor_data[0].tensor_name)
628
629  def testRunsUnderDebugModeWithWatchFnFilteringNodeNames(self):
630    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
631        [["run", "--node_name_filter", "inc.*"],
632         ["run", "--node_name_filter", "delta"],
633         ["run"]],
634        self.sess, dump_root=self._tmp_dir)
635
636    # run under debug mode twice.
637    wrapped_sess.run(self.inc_v)
638    wrapped_sess.run(self.inc_v)
639
640    # Verify that the assign_add op did take effect.
641    self.assertAllClose(12.0, self.sess.run(self.v))
642
643    # Verify that the dumps have been generated and picked up during run-end.
644    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
645
646    dumps = wrapped_sess.observers["debug_dumps"][0]
647    self.assertEqual(1, dumps.size)
648    self.assertEqual("inc_v", dumps.dumped_tensor_data[0].node_name)
649
650    dumps = wrapped_sess.observers["debug_dumps"][1]
651    self.assertEqual(1, dumps.size)
652    self.assertEqual("delta", dumps.dumped_tensor_data[0].node_name)
653
654  def testRunsUnderDebugModeWithWatchFnFilteringOpTypes(self):
655    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
656        [["run", "--node_name_filter", "delta"],
657         ["run", "--op_type_filter", "AssignAdd"],
658         ["run"]],
659        self.sess, dump_root=self._tmp_dir)
660
661    # run under debug mode twice.
662    wrapped_sess.run(self.inc_v)
663    wrapped_sess.run(self.inc_v)
664
665    # Verify that the assign_add op did take effect.
666    self.assertAllClose(12.0, self.sess.run(self.v))
667
668    # Verify that the dumps have been generated and picked up during run-end.
669    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
670
671    dumps = wrapped_sess.observers["debug_dumps"][0]
672    self.assertEqual(1, dumps.size)
673    self.assertEqual("delta", dumps.dumped_tensor_data[0].node_name)
674
675    dumps = wrapped_sess.observers["debug_dumps"][1]
676    self.assertEqual(1, dumps.size)
677    self.assertEqual("inc_v", dumps.dumped_tensor_data[0].node_name)
678
679  def testRunsUnderDebugModeWithWatchFnFilteringTensorDTypes(self):
680    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
681        [["run", "--op_type_filter", "Variable.*"],
682         ["run", "--tensor_dtype_filter", "int32"],
683         ["run"]],
684        self.sess, dump_root=self._tmp_dir)
685
686    # run under debug mode twice.
687    wrapped_sess.run(self.w_int)
688    wrapped_sess.run(self.w_int)
689
690    # Verify that the dumps have been generated and picked up during run-end.
691    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
692
693    dumps = wrapped_sess.observers["debug_dumps"][0]
694    self.assertEqual(2, dumps.size)
695    self.assertItemsEqual(
696        ["v", "w"], [dumps.dumped_tensor_data[i].node_name for i in [0, 1]])
697
698    dumps = wrapped_sess.observers["debug_dumps"][1]
699    self.assertEqual(2, dumps.size)
700    self.assertEqual(
701        ["w_int_inner", "w_int_outer"],
702        [dumps.dumped_tensor_data[i].node_name for i in [0, 1]])
703
704  def testRunsUnderDebugModeWithWatchFnFilteringOpTypesAndTensorDTypes(self):
705    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
706        [["run", "--op_type_filter", "Cast", "--tensor_dtype_filter", "int32"],
707         ["run"]],
708        self.sess, dump_root=self._tmp_dir)
709
710    # run under debug mode twice.
711    wrapped_sess.run(self.w_int)
712
713    # Verify that the dumps have been generated and picked up during run-end.
714    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
715
716    dumps = wrapped_sess.observers["debug_dumps"][0]
717    self.assertEqual(1, dumps.size)
718    self.assertEqual("w_int_inner", dumps.dumped_tensor_data[0].node_name)
719
720  def testPrintFeedPrintsFeedValueForTensorFeedKey(self):
721    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
722        [["print_feed", "ph:0"], ["run"], ["run"]], self.sess)
723
724    self.assertAllClose(
725        [[5.0], [-1.0]],
726        wrapped_sess.run(self.y, feed_dict={self.ph: [[0.0, 1.0, 2.0]]}))
727    print_feed_responses = wrapped_sess.observers["print_feed_responses"]
728    self.assertEqual(1, len(print_feed_responses))
729    self.assertEqual(
730        ["Tensor \"ph:0 (feed)\":", "", "[[0.0, 1.0, 2.0]]"],
731        print_feed_responses[0].lines)
732
733  def testPrintFeedPrintsFeedValueForTensorNameFeedKey(self):
734    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
735        [["print_feed", "ph:0"], ["run"], ["run"]], self.sess)
736
737    self.assertAllClose(
738        [[5.0], [-1.0]],
739        wrapped_sess.run(self.y, feed_dict={"ph:0": [[0.0, 1.0, 2.0]]}))
740    print_feed_responses = wrapped_sess.observers["print_feed_responses"]
741    self.assertEqual(1, len(print_feed_responses))
742    self.assertEqual(
743        ["Tensor \"ph:0 (feed)\":", "", "[[0.0, 1.0, 2.0]]"],
744        print_feed_responses[0].lines)
745
746  def testPrintFeedPrintsErrorForInvalidFeedKey(self):
747    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
748        [["print_feed", "spam"], ["run"], ["run"]], self.sess)
749
750    self.assertAllClose(
751        [[5.0], [-1.0]],
752        wrapped_sess.run(self.y, feed_dict={"ph:0": [[0.0, 1.0, 2.0]]}))
753    print_feed_responses = wrapped_sess.observers["print_feed_responses"]
754    self.assertEqual(1, len(print_feed_responses))
755    self.assertEqual(
756        ["ERROR: The feed_dict of the current run does not contain the key "
757         "spam"], print_feed_responses[0].lines)
758
759  def testPrintFeedPrintsErrorWhenFeedDictIsNone(self):
760    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
761        [["print_feed", "spam"], ["run"], ["run"]], self.sess)
762
763    wrapped_sess.run(self.w_int)
764    print_feed_responses = wrapped_sess.observers["print_feed_responses"]
765    self.assertEqual(1, len(print_feed_responses))
766    self.assertEqual(
767        ["ERROR: The feed_dict of the current run is None or empty."],
768        print_feed_responses[0].lines)
769
770  def testRunUnderProfilerModeWorks(self):
771    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
772        [["run", "-p"], ["run"]], self.sess)
773
774    wrapped_sess.run(self.w_int)
775
776    self.assertEqual(1, len(wrapped_sess.observers["profiler_run_metadata"]))
777    self.assertTrue(
778        wrapped_sess.observers["profiler_run_metadata"][0].step_stats)
779    self.assertEqual(1, len(wrapped_sess.observers["profiler_py_graphs"]))
780    self.assertIsInstance(
781        wrapped_sess.observers["profiler_py_graphs"][0], ops.Graph)
782
783  def testCallingHookDelBeforeAnyRun(self):
784    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
785        [["run"], ["run"]], self.sess)
786    del wrapped_sess
787
788  def testCallingShouldStopMethodOnNonWrappedNonMonitoredSessionErrors(self):
789    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
790        [["run"], ["run"]], self.sess)
791    with self.assertRaisesRegex(
792        ValueError,
793        r"The wrapped session .* does not have a method .*should_stop.*"):
794      wrapped_sess.should_stop()
795
796  def testLocalCLIDebugWrapperSessionWorksOnMonitoredSession(self):
797    monitored_sess = monitored_session.MonitoredSession()
798    wrapped_monitored_sess = LocalCLIDebuggerWrapperSessionForTest(
799        [["run"], ["run"]], monitored_sess)
800    self.assertFalse(wrapped_monitored_sess.should_stop())
801
802  def testRunsWithEmptyFetchWorks(self):
803    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
804        [["run"]], self.sess, dump_root="")
805
806    run_output = wrapped_sess.run([])
807    self.assertEqual([], run_output)
808
809  def testRunsWithEmptyNestedFetchWorks(self):
810    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
811        [["run"]], self.sess, dump_root="")
812
813    run_output = wrapped_sess.run({"foo": {"baz": []}, "bar": ()})
814    self.assertEqual({"foo": {"baz": []}, "bar": ()}, run_output)
815
816  def testSessionRunHook(self):
817    a = array_ops.placeholder(dtypes.float32, [10])
818    b = a + 1
819    c = b * 2
820
821    class Hook(session_run_hook.SessionRunHook):
822
823      def before_run(self, _):
824        return session_run_hook.SessionRunArgs(fetches=c)
825
826    class Hook2(session_run_hook.SessionRunHook):
827
828      def before_run(self, _):
829        return session_run_hook.SessionRunArgs(fetches=b)
830
831    sess = session.Session()
832    sess = LocalCLIDebuggerWrapperSessionForTest([["run"], ["run"]], sess)
833
834    class SessionCreator(object):
835
836      def create_session(self):
837        return sess
838
839    final_sess = monitored_session.MonitoredSession(
840        session_creator=SessionCreator(), hooks=[Hook(), Hook2()])
841
842    final_sess.run(b, feed_dict={a: np.arange(10)})
843    debug_dumps = sess.observers["debug_dumps"]
844    self.assertEqual(1, len(debug_dumps))
845    debug_dump = debug_dumps[0]
846    node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
847    self.assertIn(b.op.name, node_names)
848
849
if __name__ == "__main__":
  # Entry point when this file is executed directly: hand control to
  # googletest.main() to run the tests defined in this module.
  googletest.main()
852