• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for the framework of debug-wrapped sessions."""
16import os
17import tempfile
18import threading
19
20import numpy as np
21
22from tensorflow.core.protobuf import config_pb2
23from tensorflow.core.protobuf import rewriter_config_pb2
24from tensorflow.python.client import session
25from tensorflow.python.debug.lib import debug_data
26from tensorflow.python.debug.wrappers import framework
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import errors
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import test_util
32from tensorflow.python.lib.io import file_io
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import math_ops
35# Import resource_variable_ops for the variables-to-tensor implicit conversion.
36from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
37from tensorflow.python.ops import variables
38from tensorflow.python.platform import googletest
39from tensorflow.python.training import monitored_session
40from tensorflow.python.util import tf_inspect
41
42
class TestDebugWrapperSession(framework.BaseDebugWrapperSession):
  """Debug wrapper used by the tests; logs every callback into an observer.

  Each callback writes what it received into the caller-supplied `observer`
  dict, and every run is requested as a DEBUG_RUN that dumps tensors to the
  `file://` URL built from `dump_root`.
  """

  def __init__(self, sess, dump_root, observer, thread_name_filter=None):
    # Both attributes must exist before the base constructor runs, because
    # constructing the wrapper triggers on_session_init(), which writes into
    # the observer.
    self._dump_root = dump_root
    self._obs = observer
    super().__init__(sess, thread_name_filter=thread_name_filter)

  def on_session_init(self, request):
    """Counts the init callback, remembers its session, and proceeds."""
    obs = self._obs
    obs["sess_init_count"] += 1
    obs["request_sess"] = request.session
    proceed = framework.OnSessionInitAction.PROCEED
    return framework.OnSessionInitResponse(proceed)

  def on_run_start(self, request):
    """Records this run's fetches/feeds and requests a file-dumping debug run."""
    obs = self._obs
    obs["on_run_start_count"] += 1
    obs["run_fetches"] = request.fetches
    obs["run_feed_dict"] = request.feed_dict
    action = framework.OnRunStartAction.DEBUG_RUN
    dump_urls = ["file://" + self._dump_root]
    return framework.OnRunStartResponse(action, dump_urls)

  def on_run_end(self, request):
    """Records the performed action and any runtime error of the run."""
    obs = self._obs
    obs["on_run_end_count"] += 1
    obs["performed_action"] = request.performed_action
    obs["tf_error"] = request.tf_error
    return framework.OnRunEndResponse()
85
86
class TestDebugWrapperSessionBadAction(framework.BaseDebugWrapperSession):
  """Debug wrapper whose callbacks can return deliberately invalid values.

  Used to exercise the framework's handling of a bad action value in the
  OnSessionInitResponse and/or the OnRunStartResponse, and of malformed debug
  URLs returned from the on-run-start callback.
  """

  def __init__(
      self,
      sess,
      bad_init_action=None,
      bad_run_start_action=None,
      bad_debug_urls=None):
    """Constructor.

    Args:
      sess: The TensorFlow Session object to be wrapped.
      bad_init_action: (str) Bad action value to be returned by the
        on-session-init callback. If falsy, PROCEED is returned instead.
      bad_run_start_action: (str) Bad action value to be returned by the
        on-run-start callback. If falsy, DEBUG_RUN is returned instead.
      bad_debug_urls: Bad URL value(s) to be returned by the on-run-start
        callback. If falsy, an empty list is returned instead.
    """
    # Stored before the base constructor runs, since constructing the wrapper
    # invokes on_session_init().
    self._bad_init_action = bad_init_action
    self._bad_run_start_action = bad_run_start_action
    self._bad_debug_urls = bad_debug_urls

    super().__init__(sess)

  def on_session_init(self, request):
    action = self._bad_init_action or framework.OnSessionInitAction.PROCEED
    return framework.OnSessionInitResponse(action)

  def on_run_start(self, request):
    urls = self._bad_debug_urls or []
    action = (self._bad_run_start_action or
              framework.OnRunStartAction.DEBUG_RUN)
    return framework.OnRunStartResponse(action, urls)

  def on_run_end(self, request):
    return framework.OnRunEndResponse()
138
139
@test_util.run_v1_only("Sessions are not available in TF 2.x")
class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
  """Tests BaseDebugWrapperSession through the TestDebugWrapperSession* classes."""

  def _no_rewrite_session_config(self):
    """Returns a ConfigProto whose graph rewriter has model pruning disabled."""
    # NOTE(review): disabling pruning presumably keeps every node in the
    # executed graph so the dumped-tensor counts asserted below are stable.
    rewriter_config = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True)
    graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
    return config_pb2.ConfigProto(graph_options=graph_options)

  def setUp(self):
    super().setUp()

    # Mutable record that TestDebugWrapperSession's callbacks write into.
    self._observer = {
        "sess_init_count": 0,
        "request_sess": None,
        "on_run_start_count": 0,
        "run_fetches": None,
        "run_feed_dict": None,
        "on_run_end_count": 0,
        "performed_action": None,
        "tf_error": None,
    }

    # Directory targeted by the wrapper's file:// debug URL.
    self._dump_root = tempfile.mkdtemp()

    self._sess = session.Session(config=self._no_rewrite_session_config())

    self._a_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
    self._b_init_val = np.array([[2.0], [-1.0]])
    self._c_val = np.array([[-4.0], [6.0]])

    self._a_init = constant_op.constant(
        self._a_init_val, shape=[2, 2], name="a_init")
    self._b_init = constant_op.constant(
        self._b_init_val, shape=[2, 1], name="b_init")

    self._ph = array_ops.placeholder(dtype=dtypes.float64, name="ph")

    self._a = variables.Variable(self._a_init, name="a1")
    self._b = variables.Variable(self._b_init, name="b")
    self._c = constant_op.constant(self._c_val, shape=[2, 1], name="c")

    # Matrix product of a and b.
    self._p = math_ops.matmul(self._a, self._b, name="p1")

    # Matrix product of a and ph.
    self._q = math_ops.matmul(self._a, self._ph, name="q")

    # Sum of two vectors.
    self._s = math_ops.add(self._p, self._c, name="s")

    # Initialize the variables.
    self._sess.run(self._a.initializer)
    self._sess.run(self._b.initializer)

  def tearDown(self):
    # Release the session created in setUp so it does not outlive the test.
    self._sess.close()

    # Tear down temporary dump directory.
    if os.path.isdir(self._dump_root):
      file_io.delete_recursively(self._dump_root)

    ops.reset_default_graph()
    super().tearDown()

  def testSessionInit(self):
    self.assertEqual(0, self._observer["sess_init_count"])

    wrapper_sess = TestDebugWrapperSession(self._sess, self._dump_root,
                                           self._observer)

    # Assert that on-session-init callback is invoked.
    self.assertEqual(1, self._observer["sess_init_count"])

    # Assert that the request to the on-session-init callback carries the
    # correct session object.
    self.assertEqual(self._sess, self._observer["request_sess"])

    # Verify that the wrapper session implements the session.SessionInterface.
    self.assertIsInstance(wrapper_sess, session.SessionInterface)
    self.assertEqual(self._sess.sess_str, wrapper_sess.sess_str)
    self.assertEqual(self._sess.graph, wrapper_sess.graph)
    self.assertEqual(self._sess.graph_def, wrapper_sess.graph_def)

    # Check that partial_run_setup is not implemented for the debug wrapper
    # session.
    with self.assertRaises(NotImplementedError):
      wrapper_sess.partial_run_setup(self._p)

  def testInteractiveSessionInit(self):
    """The wrapper should work also on other subclasses of session.Session."""
    interactive_sess = session.InteractiveSession()
    try:
      TestDebugWrapperSession(
          interactive_sess, self._dump_root, self._observer)
      self.assertEqual(1, self._observer["sess_init_count"])
      self.assertIs(interactive_sess, self._observer["request_sess"])
    finally:
      # Close explicitly so the interactive session does not outlive this
      # test.
      interactive_sess.close()

  def testSessionRun(self):
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer)

    # Check initial state of the observer.
    self.assertEqual(0, self._observer["on_run_start_count"])
    self.assertEqual(0, self._observer["on_run_end_count"])

    s = wrapper.run(self._s)

    # Assert the run return value is correct.
    self.assertAllClose(np.array([[3.0], [4.0]]), s)

    # Assert the on-run-start method is invoked.
    self.assertEqual(1, self._observer["on_run_start_count"])

    # Assert the on-run-start request reflects the correct fetch.
    self.assertEqual(self._s, self._observer["run_fetches"])

    # Assert the on-run-start request reflects the correct feed_dict.
    self.assertIsNone(self._observer["run_feed_dict"])

    # Assert the file debug URL has led to dump on the filesystem.
    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertEqual(7, len(dump.dumped_tensor_data))

    # Assert the on-run-end method is invoked.
    self.assertEqual(1, self._observer["on_run_end_count"])

    # Assert the performed action field in the on-run-end callback request is
    # correct.
    self.assertEqual(
        framework.OnRunStartAction.DEBUG_RUN,
        self._observer["performed_action"])

    # No TensorFlow runtime error should have happened.
    self.assertIsNone(self._observer["tf_error"])

  def testSessionInitInvalidSessionType(self):
    """Attempt to wrap a non-Session-type object should cause an exception."""

    wrapper = TestDebugWrapperSessionBadAction(self._sess)
    with self.assertRaisesRegex(TypeError, "Expected type .*; got type .*"):
      TestDebugWrapperSessionBadAction(wrapper)

  def testSessionInitBadActionValue(self):
    with self.assertRaisesRegex(
        ValueError, "Invalid OnSessionInitAction value: nonsense_action"):
      TestDebugWrapperSessionBadAction(
          self._sess, bad_init_action="nonsense_action")

  def testRunStartBadActionValue(self):
    wrapper = TestDebugWrapperSessionBadAction(
        self._sess, bad_run_start_action="nonsense_action")

    with self.assertRaisesRegex(
        ValueError, "Invalid OnRunStartAction value: nonsense_action"):
      wrapper.run(self._s)

  def testRunStartBadURLs(self):
    # debug_urls ought to be a list of str, not a str. So an exception should
    # be raised during a run() call.
    wrapper = TestDebugWrapperSessionBadAction(
        self._sess, bad_debug_urls="file://foo")

    with self.assertRaisesRegex(TypeError, "Expected type .*; got type .*"):
      wrapper.run(self._s)

  def testErrorDuringRun(self):

    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)

    # No matrix size mismatch.
    self.assertAllClose(
        np.array([[11.0], [-1.0]]),
        wrapper.run(self._q, feed_dict={self._ph: np.array([[1.0], [2.0]])}))
    self.assertEqual(1, self._observer["on_run_end_count"])
    self.assertIsNone(self._observer["tf_error"])

    # Now there should be a matrix size mismatch error.
    wrapper.run(self._q, feed_dict={self._ph: np.array([[1.0], [2.0], [3.0]])})
    self.assertEqual(2, self._observer["on_run_end_count"])
    self.assertIsInstance(
        self._observer["tf_error"], errors.InvalidArgumentError)

  def testUsingWrappedSessionShouldWorkAsContextManager(self):
    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)

    with wrapper as sess:
      # The tensor is evaluated through the default session installed by the
      # `with` statement; the observer checks below confirm it went through
      # the wrapper.
      self.assertAllClose([[3.0], [4.0]], self._s)
      self.assertEqual(1, self._observer["on_run_start_count"])
      self.assertEqual(self._s, self._observer["run_fetches"])
      self.assertEqual(1, self._observer["on_run_end_count"])

      self.assertAllClose(
          [[11.0], [-1.0]],
          sess.run(self._q, feed_dict={self._ph: np.array([[1.0], [2.0]])}))
      self.assertEqual(2, self._observer["on_run_start_count"])
      self.assertEqual(self._q, self._observer["run_fetches"])
      self.assertEqual(2, self._observer["on_run_end_count"])

  def testUsingWrappedSessionShouldSupportEvalWithAsDefault(self):
    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)

    with wrapper.as_default():
      foo = constant_op.constant(42, name="foo")
      self.assertEqual(42, self.evaluate(foo))
      self.assertEqual(foo, self._observer["run_fetches"])

  def testWrapperShouldSupportSessionClose(self):
    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)
    wrapper.close()

  def testWrapperThreadNameFilterMainThread(self):
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer,
        thread_name_filter="MainThread")

    child_run_output = []
    def child_thread_job():
      child_run_output.append(wrapper.run(self._b_init))

    thread = threading.Thread(name="ChildThread", target=child_thread_job)
    thread.start()
    self.assertAllClose(self._a_init_val, wrapper.run(self._a_init))
    thread.join()
    self.assertAllClose([self._b_init_val], child_run_output)

    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertEqual(1, dump.size)
    self.assertEqual("a_init", dump.dumped_tensor_data[0].node_name)

  def testWrapperThreadNameFilterChildThread(self):
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer,
        thread_name_filter=r"Child.*")

    child_run_output = []
    def child_thread_job():
      child_run_output.append(wrapper.run(self._b_init))

    thread = threading.Thread(name="ChildThread", target=child_thread_job)
    thread.start()
    self.assertAllClose(self._a_init_val, wrapper.run(self._a_init))
    thread.join()
    self.assertAllClose([self._b_init_val], child_run_output)

    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertEqual(1, dump.size)
    self.assertEqual("b_init", dump.dumped_tensor_data[0].node_name)

  def testWrapperThreadNameFilterBothThreads(self):
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer,
        thread_name_filter=None)

    child_run_output = []
    def child_thread_job():
      child_run_output.append(wrapper.run(self._b_init))

    thread = threading.Thread(name="ChildThread", target=child_thread_job)
    thread.start()
    self.assertAllClose(self._a_init_val, wrapper.run(self._a_init))
    thread.join()
    self.assertAllClose([self._b_init_val], child_run_output)

    dump = debug_data.DebugDumpDir(self._dump_root, validate=False)
    self.assertEqual(2, dump.size)
    self.assertItemsEqual(
        ["a_init", "b_init"],
        [datum.node_name for datum in dump.dumped_tensor_data])
405
406
407def _is_public_method_name(method_name):
408  return (method_name.startswith("__") and method_name.endswith("__")
409          or not method_name.startswith("_"))
410
411
class SessionWrapperPublicMethodParityTest(test_util.TensorFlowTestCase):
  """Checks the debug wrapper exposes the public methods of session classes."""

  @staticmethod
  def _public_method_names(cls_to_inspect):
    """Returns public-looking names of members accepted by tf_inspect.ismethod.

    NOTE(review): on Python 3, a plain function looked up on a class is not a
    bound method, so `ismethod` presumably matches only classmethods here and
    these parity checks may be close to vacuous. Tightening the predicate
    (e.g. also accepting functions) would need an audit of the wrapper's API
    first — TODO confirm.
    """
    members = tf_inspect.getmembers(
        cls_to_inspect, predicate=tf_inspect.ismethod)
    return [name for name, _ in members if _is_public_method_name(name)]

  def _assert_wrapper_has_public_methods_of(self, cls_to_inspect):
    """Fails if BaseDebugWrapperSession lacks a public method of the class."""
    expected = self._public_method_names(cls_to_inspect)
    available = self._public_method_names(framework.BaseDebugWrapperSession)
    missing_public_methods = [
        name for name in expected if name not in available]
    self.assertFalse(missing_public_methods)

  def testWrapperHasAllPublicMethodsOfSession(self):
    self._assert_wrapper_has_public_methods_of(session.Session)

  def testWrapperHasAllPublicMethodsOfMonitoredSession(self):
    self._assert_wrapper_has_public_methods_of(
        monitored_session.MonitoredSession)
444
445
if __name__ == "__main__":
  # Run this module's tests via googletest when executed as a script.
  googletest.main()
448