# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the framework of debug-wrapped sessions."""
import os
import tempfile
import threading

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import framework
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
# Import resource_variable_ops for the variables-to-tensor implicit conversion.
from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import monitored_session
from tensorflow.python.util import tf_inspect


class TestDebugWrapperSession(framework.BaseDebugWrapperSession):
  """A concrete implementation of BaseDebugWrapperSession for test.

  Every callback records what it received into the caller-supplied
  `observer` dict, so tests can assert on call counts and request contents.
  Every run is requested as a DEBUG_RUN that dumps to `file://<dump_root>`.
  """

  def __init__(self, sess, dump_root, observer, thread_name_filter=None):
    """Constructor.

    Args:
      sess: The TensorFlow Session object to be wrapped.
      dump_root: (str) Directory that the debug dumps are written to.
      observer: (dict) Mutable dict that the callbacks write into. It must
        already contain the integer counters "sess_init_count",
        "on_run_start_count" and "on_run_end_count".
      thread_name_filter: Optional regex forwarded to the base class. The
        thread tests below show that only runs from threads whose names match
        it are dumped; None means all threads.
    """
    # Both attributes must exist before the superclass constructor runs,
    # because constructing the wrapper already invokes on_session_init().
    self._dump_root = dump_root
    self._obs = observer

    framework.BaseDebugWrapperSession.__init__(
        self, sess, thread_name_filter=thread_name_filter)

  def on_session_init(self, request):
    """Records the init request and lets session creation proceed."""
    self._obs["sess_init_count"] += 1
    self._obs["request_sess"] = request.session

    return framework.OnSessionInitResponse(
        framework.OnSessionInitAction.PROCEED)

  def on_run_start(self, request):
    """Records the run request and asks for a file-dumping debug run."""
    self._obs["on_run_start_count"] += 1
    self._obs["run_fetches"] = request.fetches
    self._obs["run_feed_dict"] = request.feed_dict

    return framework.OnRunStartResponse(
        framework.OnRunStartAction.DEBUG_RUN,
        ["file://" + self._dump_root])

  def on_run_end(self, request):
    """Records the performed action and any TensorFlow runtime error."""
    self._obs["on_run_end_count"] += 1
    self._obs["performed_action"] = request.performed_action
    self._obs["tf_error"] = request.tf_error

    return framework.OnRunEndResponse()


class TestDebugWrapperSessionBadAction(framework.BaseDebugWrapperSession):
  """A concrete implementation of BaseDebugWrapperSession for test.

  This class intentionally puts a bad action value in OnSessionInitResponse
  and/or in OnRunStartResponse to test the handling of such invalid cases.
  """

  def __init__(
      self,
      sess,
      bad_init_action=None,
      bad_run_start_action=None,
      bad_debug_urls=None):
    """Constructor.

    Args:
      sess: The TensorFlow Session object to be wrapped.
      bad_init_action: (str) bad action value to be returned during the
        on-session-init callback.
      bad_run_start_action: (str) bad action value to be returned during the
        on-run-start callback.
      bad_debug_urls: Bad URL values to be returned during the on-run-start
        callback.
    """

    self._bad_init_action = bad_init_action
    self._bad_run_start_action = bad_run_start_action
    self._bad_debug_urls = bad_debug_urls

    # Invoke superclass constructor.
    framework.BaseDebugWrapperSession.__init__(self, sess)

  def on_session_init(self, request):
    """Returns the configured bad init action, or PROCEED if none."""
    if self._bad_init_action:
      return framework.OnSessionInitResponse(self._bad_init_action)
    else:
      return framework.OnSessionInitResponse(
          framework.OnSessionInitAction.PROCEED)

  def on_run_start(self, request):
    """Returns the configured bad run-start action and/or debug URLs."""
    debug_urls = self._bad_debug_urls or []

    if self._bad_run_start_action:
      return framework.OnRunStartResponse(
          self._bad_run_start_action, debug_urls)
    else:
      return framework.OnRunStartResponse(
          framework.OnRunStartAction.DEBUG_RUN, debug_urls)

  def on_run_end(self, request):
    """Returns a default OnRunEndResponse."""
    return framework.OnRunEndResponse()


@test_util.run_v1_only("Sessions are not available in TF 2.x")
class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
  """Tests the callbacks and Session-interface behavior of debug wrappers."""

  def _no_rewrite_session_config(self):
    """Returns a ConfigProto whose RewriterConfig disables model pruning."""
    rewriter_config = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True)
    graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
    return config_pb2.ConfigProto(graph_options=graph_options)

  def setUp(self):
    """Creates the observer dict, a dump directory, a session and a graph.

    Graph: variables a (2x2) and b (2x1), constant c (2x1), placeholder ph,
    p = a @ b, q = a @ ph and s = p + c.
    """
    super().setUp()

    self._observer = {
        "sess_init_count": 0,
        "request_sess": None,
        "on_run_start_count": 0,
        "run_fetches": None,
        "run_feed_dict": None,
        "on_run_end_count": 0,
        "performed_action": None,
        "tf_error": None,
    }

    self._dump_root = tempfile.mkdtemp()

    self._sess = session.Session(config=self._no_rewrite_session_config())

    self._a_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
    self._b_init_val = np.array([[2.0], [-1.0]])
    self._c_val = np.array([[-4.0], [6.0]])

    self._a_init = constant_op.constant(
        self._a_init_val, shape=[2, 2], name="a_init")
    self._b_init = constant_op.constant(
        self._b_init_val, shape=[2, 1], name="b_init")

    self._ph = array_ops.placeholder(dtype=dtypes.float64, name="ph")

    self._a = variables.Variable(self._a_init, name="a1")
    self._b = variables.Variable(self._b_init, name="b")
    self._c = constant_op.constant(self._c_val, shape=[2, 1], name="c")

    # Matrix product of a and b.
    self._p = math_ops.matmul(self._a, self._b, name="p1")

    # Matrix product of a and ph.
    self._q = math_ops.matmul(self._a, self._ph, name="q")

    # Sum of two vectors.
    self._s = math_ops.add(self._p, self._c, name="s")

    # Initialize the variables.
    self._sess.run(self._a.initializer)
    self._sess.run(self._b.initializer)

  def tearDown(self):
    """Closes the session and removes the dump directory and default graph."""
    # Closing an already-closed session is harmless, so this is safe even
    # after tests that close the session through the wrapper.
    self._sess.close()

    # Tear down temporary dump directory.
    if os.path.isdir(self._dump_root):
      file_io.delete_recursively(self._dump_root)

    ops.reset_default_graph()
    super().tearDown()

  def _run_in_main_and_child_threads(self, wrapper):
    """Fetches a_init on the main thread and b_init on a "ChildThread".

    Asserts that both fetches return the expected values, whichever thread
    the wrapper is configured to debug.

    Args:
      wrapper: The debug-wrapper session to run through.
    """
    child_run_output = []

    def child_thread_job():
      child_run_output.append(wrapper.run(self._b_init))

    thread = threading.Thread(name="ChildThread", target=child_thread_job)
    thread.start()
    self.assertAllClose(self._a_init_val, wrapper.run(self._a_init))
    thread.join()
    # If the child run raised, child_run_output stays empty and this fails.
    self.assertAllClose([self._b_init_val], child_run_output)

  def testSessionInit(self):
    self.assertEqual(0, self._observer["sess_init_count"])

    wrapper_sess = TestDebugWrapperSession(self._sess, self._dump_root,
                                           self._observer)

    # Assert that on-session-init callback is invoked.
    self.assertEqual(1, self._observer["sess_init_count"])

    # Assert that the request to the on-session-init callback carries the
    # correct session object.
    self.assertEqual(self._sess, self._observer["request_sess"])

    # Verify that the wrapper session implements the session.SessionInterface.
    self.assertIsInstance(wrapper_sess, session.SessionInterface)
    self.assertEqual(self._sess.sess_str, wrapper_sess.sess_str)
    self.assertEqual(self._sess.graph, wrapper_sess.graph)
    self.assertEqual(self._sess.graph_def, wrapper_sess.graph_def)

    # Check that partial_run_setup is not implemented for the debug wrapper
    # session.
    with self.assertRaises(NotImplementedError):
      wrapper_sess.partial_run_setup(self._p)

  def testInteractiveSessionInit(self):
    """The wrapper should work also on other subclasses of session.Session."""
    interactive_sess = session.InteractiveSession()
    try:
      TestDebugWrapperSession(
          interactive_sess, self._dump_root, self._observer)
      self.assertEqual(1, self._observer["sess_init_count"])
      self.assertEqual(interactive_sess, self._observer["request_sess"])
    finally:
      # An InteractiveSession installs itself as the default session; close
      # it so it does not leak into later tests.
      interactive_sess.close()

  def testSessionRun(self):
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer)

    # Check initial state of the observer.
    self.assertEqual(0, self._observer["on_run_start_count"])
    self.assertEqual(0, self._observer["on_run_end_count"])

    s = wrapper.run(self._s)

    # Assert the run return value is correct.
    self.assertAllClose(np.array([[3.0], [4.0]]), s)

    # Assert the on-run-start method is invoked.
    self.assertEqual(1, self._observer["on_run_start_count"])

    # Assert the on-run-start request reflects the correct fetch.
    self.assertEqual(self._s, self._observer["run_fetches"])

    # Assert the on-run-start request reflects the correct feed_dict.
    self.assertIsNone(self._observer["run_feed_dict"])

    # Assert the file debug URL has led to dump on the filesystem.
    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertEqual(7, len(dump.dumped_tensor_data))

    # Assert the on-run-end method is invoked.
    self.assertEqual(1, self._observer["on_run_end_count"])

    # Assert the performed action field in the on-run-end callback request is
    # correct.
    self.assertEqual(
        framework.OnRunStartAction.DEBUG_RUN,
        self._observer["performed_action"])

    # No TensorFlow runtime error should have happened.
    self.assertIsNone(self._observer["tf_error"])

  def testSessionInitInvalidSessionType(self):
    """Attempt to wrap a non-Session-type object should cause an exception."""

    wrapper = TestDebugWrapperSessionBadAction(self._sess)
    with self.assertRaisesRegex(TypeError, "Expected type .*; got type .*"):
      TestDebugWrapperSessionBadAction(wrapper)

  def testSessionInitBadActionValue(self):
    with self.assertRaisesRegex(
        ValueError, "Invalid OnSessionInitAction value: nonsense_action"):
      TestDebugWrapperSessionBadAction(
          self._sess, bad_init_action="nonsense_action")

  def testRunStartBadActionValue(self):
    wrapper = TestDebugWrapperSessionBadAction(
        self._sess, bad_run_start_action="nonsense_action")

    with self.assertRaisesRegex(
        ValueError, "Invalid OnRunStartAction value: nonsense_action"):
      wrapper.run(self._s)

  def testRunStartBadURLs(self):
    # debug_urls ought to be a list of str, not a str. So an exception should
    # be raised during a run() call.
    wrapper = TestDebugWrapperSessionBadAction(
        self._sess, bad_debug_urls="file://foo")

    with self.assertRaisesRegex(TypeError, "Expected type .*; got type .*"):
      wrapper.run(self._s)

  def testErrorDuringRun(self):

    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)

    # No matrix size mismatch.
    self.assertAllClose(
        np.array([[11.0], [-1.0]]),
        wrapper.run(self._q, feed_dict={self._ph: np.array([[1.0], [2.0]])}))
    self.assertEqual(1, self._observer["on_run_end_count"])
    self.assertIsNone(self._observer["tf_error"])

    # Now there should be a matrix size mismatch error. The wrapper does not
    # raise it; it is reported to the on-run-end callback instead.
    wrapper.run(self._q, feed_dict={self._ph: np.array([[1.0], [2.0], [3.0]])})
    self.assertEqual(2, self._observer["on_run_end_count"])
    self.assertIsInstance(
        self._observer["tf_error"], errors.InvalidArgumentError)

  def testUsingWrappedSessionShouldWorkAsContextManager(self):
    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)

    with wrapper as sess:
      # The tensor itself (not a value) is passed, so assertAllClose has to
      # evaluate it. The counters below check that this evaluation went
      # through the wrapper, i.e. that entering the wrapper made it the
      # default session.
      self.assertAllClose([[3.0], [4.0]], self._s)
      self.assertEqual(1, self._observer["on_run_start_count"])
      self.assertEqual(self._s, self._observer["run_fetches"])
      self.assertEqual(1, self._observer["on_run_end_count"])

      self.assertAllClose(
          [[11.0], [-1.0]],
          sess.run(self._q, feed_dict={self._ph: np.array([[1.0], [2.0]])}))
      self.assertEqual(2, self._observer["on_run_start_count"])
      self.assertEqual(self._q, self._observer["run_fetches"])
      self.assertEqual(2, self._observer["on_run_end_count"])

  def testUsingWrappedSessionShouldSupportEvalWithAsDefault(self):
    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)

    with wrapper.as_default():
      foo = constant_op.constant(42, name="foo")
      self.assertEqual(42, self.evaluate(foo))
      # The evaluation must have been routed through the wrapper.
      self.assertEqual(foo, self._observer["run_fetches"])

  def testWrapperShouldSupportSessionClose(self):
    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)
    wrapper.close()

  def testWrapperThreadNameFilterMainThread(self):
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer,
        thread_name_filter="MainThread")

    self._run_in_main_and_child_threads(wrapper)

    # Only the main thread's run (a_init) should have been dumped.
    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertEqual(1, dump.size)
    self.assertEqual("a_init", dump.dumped_tensor_data[0].node_name)

  def testWrapperThreadNameFilterChildThread(self):
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer,
        thread_name_filter=r"Child.*")

    self._run_in_main_and_child_threads(wrapper)

    # Only the child thread's run (b_init) should have been dumped.
    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertEqual(1, dump.size)
    self.assertEqual("b_init", dump.dumped_tensor_data[0].node_name)

  def testWrapperThreadNameFilterBothThreads(self):
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer,
        thread_name_filter=None)

    self._run_in_main_and_child_threads(wrapper)

    # With no filter, both runs are dumped into the same root.
    dump = debug_data.DebugDumpDir(self._dump_root, validate=False)
    self.assertEqual(2, dump.size)
    self.assertItemsEqual(
        ["a_init", "b_init"],
        [datum.node_name for datum in dump.dumped_tensor_data])


def _is_public_method_name(method_name):
  """Returns True for dunder names and for names without a leading "_"."""
  return ((method_name.startswith("__") and method_name.endswith("__")) or
          not method_name.startswith("_"))


def _is_method_or_function(obj):
  """Returns True if `obj` is a (possibly decorated) method or function."""
  # In Python 3, a method looked up on a class rather than on an instance is
  # a plain function, so `ismethod` alone matches almost nothing there and
  # would make the parity checks below pass vacuously.
  return tf_inspect.isfunction(obj) or tf_inspect.ismethod(obj)


def _public_method_names(cls):
  """Returns the set of public method names available on class `cls`."""
  return {
      name
      for name, _ in tf_inspect.getmembers(
          cls, predicate=_is_method_or_function)
      if _is_public_method_name(name)
  }


class SessionWrapperPublicMethodParityTest(test_util.TensorFlowTestCase):
  """Checks that the debug wrapper exposes the public API it stands in for."""

  def _assert_wrapper_has_public_methods_of(self, cls):
    """Asserts that BaseDebugWrapperSession has every public method of cls."""
    missing_public_methods = sorted(
        _public_method_names(cls) -
        _public_method_names(framework.BaseDebugWrapperSession))
    self.assertFalse(
        missing_public_methods,
        "BaseDebugWrapperSession is missing public methods of %s: %s" %
        (cls.__name__, missing_public_methods))

  def testWrapperHasAllPublicMethodsOfSession(self):
    self._assert_wrapper_has_public_methods_of(session.Session)

  def testWrapperHasAllPublicMethodsOfMonitoredSession(self):
    self._assert_wrapper_has_public_methods_of(
        monitored_session.MonitoredSession)


if __name__ == "__main__":
  googletest.main()