1# pylint: disable=g-bad-file-header 2# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Tests for basic_session_run_hooks.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import os.path 23import shutil 24import tempfile 25import time 26 27from tensorflow.contrib.framework.python.framework import checkpoint_utils 28from tensorflow.contrib.framework.python.ops import variables 29from tensorflow.contrib.testing.python.framework import fake_summary_writer 30from tensorflow.python.client import session as session_lib 31from tensorflow.python.data.ops import dataset_ops 32from tensorflow.python.framework import constant_op 33from tensorflow.python.framework import dtypes 34from tensorflow.python.framework import errors 35from tensorflow.python.framework import meta_graph 36from tensorflow.python.framework import ops 37from tensorflow.python.framework import test_util 38from tensorflow.python.ops import array_ops 39from tensorflow.python.ops import control_flow_ops 40from tensorflow.python.ops import state_ops 41from tensorflow.python.ops import variable_scope 42from tensorflow.python.ops import variables as variables_lib 43import tensorflow.python.ops.nn_grad # pylint: disable=unused-import 44from tensorflow.python.platform import gfile 45from 
tensorflow.python.platform import test 46from tensorflow.python.platform import tf_logging 47from tensorflow.python.summary import summary as summary_lib 48from tensorflow.python.summary.writer import writer_cache 49from tensorflow.python.training import basic_session_run_hooks 50from tensorflow.python.training import monitored_session 51from tensorflow.python.training import session_run_hook 52from tensorflow.python.training import training_util 53 54 55# Provide a realistic start time for unit tests where we need to mock out 56# calls to time.time(). 57MOCK_START_TIME = 1484695987.209386 58 59 60class MockCheckpointSaverListener( 61 basic_session_run_hooks.CheckpointSaverListener): 62 63 def __init__(self): 64 self.begin_count = 0 65 self.before_save_count = 0 66 self.after_save_count = 0 67 self.end_count = 0 68 self.ask_for_stop = False 69 70 def begin(self): 71 self.begin_count += 1 72 73 def before_save(self, session, global_step): 74 self.before_save_count += 1 75 76 def after_save(self, session, global_step): 77 self.after_save_count += 1 78 if self.ask_for_stop: 79 return True 80 81 def end(self, session, global_step): 82 self.end_count += 1 83 84 def get_counts(self): 85 return { 86 'begin': self.begin_count, 87 'before_save': self.before_save_count, 88 'after_save': self.after_save_count, 89 'end': self.end_count 90 } 91 92 93class SecondOrStepTimerTest(test.TestCase): 94 95 @test_util.run_deprecated_v1 96 def test_raise_in_both_secs_and_steps(self): 97 with self.assertRaises(ValueError): 98 basic_session_run_hooks.SecondOrStepTimer(every_secs=2.0, every_steps=10) 99 100 @test_util.run_deprecated_v1 101 def test_raise_in_none_secs_and_steps(self): 102 with self.assertRaises(ValueError): 103 basic_session_run_hooks.SecondOrStepTimer() 104 105 @test.mock.patch.object(time, 'time') 106 def test_every_secs(self, mock_time): 107 mock_time.return_value = MOCK_START_TIME 108 timer = basic_session_run_hooks.SecondOrStepTimer(every_secs=1.0) 109 
self.assertTrue(timer.should_trigger_for_step(1)) 110 111 timer.update_last_triggered_step(1) 112 self.assertFalse(timer.should_trigger_for_step(1)) 113 self.assertFalse(timer.should_trigger_for_step(2)) 114 115 mock_time.return_value += 1.0 116 self.assertFalse(timer.should_trigger_for_step(1)) 117 self.assertTrue(timer.should_trigger_for_step(2)) 118 119 def test_every_steps(self): 120 timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=3) 121 self.assertTrue(timer.should_trigger_for_step(1)) 122 123 timer.update_last_triggered_step(1) 124 self.assertFalse(timer.should_trigger_for_step(1)) 125 self.assertFalse(timer.should_trigger_for_step(2)) 126 self.assertFalse(timer.should_trigger_for_step(3)) 127 self.assertTrue(timer.should_trigger_for_step(4)) 128 129 def test_update_last_triggered_step(self): 130 timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=1) 131 132 elapsed_secs, elapsed_steps = timer.update_last_triggered_step(1) 133 self.assertEqual(None, elapsed_secs) 134 self.assertEqual(None, elapsed_steps) 135 136 elapsed_secs, elapsed_steps = timer.update_last_triggered_step(5) 137 self.assertLess(0, elapsed_secs) 138 self.assertEqual(4, elapsed_steps) 139 140 elapsed_secs, elapsed_steps = timer.update_last_triggered_step(7) 141 self.assertLess(0, elapsed_secs) 142 self.assertEqual(2, elapsed_steps) 143 144 145class StopAtStepTest(test.TestCase): 146 147 def test_raise_in_both_last_step_and_num_steps(self): 148 with self.assertRaises(ValueError): 149 basic_session_run_hooks.StopAtStepHook(num_steps=10, last_step=20) 150 151 def test_stop_based_on_last_step(self): 152 h = basic_session_run_hooks.StopAtStepHook(last_step=10) 153 with ops.Graph().as_default(): 154 global_step = variables.get_or_create_global_step() 155 no_op = control_flow_ops.no_op() 156 h.begin() 157 with session_lib.Session() as sess: 158 mon_sess = monitored_session._HookedSession(sess, [h]) 159 sess.run(state_ops.assign(global_step, 5)) 160 
h.after_create_session(sess, None) 161 mon_sess.run(no_op) 162 self.assertFalse(mon_sess.should_stop()) 163 sess.run(state_ops.assign(global_step, 9)) 164 mon_sess.run(no_op) 165 self.assertFalse(mon_sess.should_stop()) 166 sess.run(state_ops.assign(global_step, 10)) 167 mon_sess.run(no_op) 168 self.assertTrue(mon_sess.should_stop()) 169 sess.run(state_ops.assign(global_step, 11)) 170 mon_sess._should_stop = False 171 mon_sess.run(no_op) 172 self.assertTrue(mon_sess.should_stop()) 173 174 def test_stop_based_on_num_step(self): 175 h = basic_session_run_hooks.StopAtStepHook(num_steps=10) 176 177 with ops.Graph().as_default(): 178 global_step = variables.get_or_create_global_step() 179 no_op = control_flow_ops.no_op() 180 h.begin() 181 with session_lib.Session() as sess: 182 mon_sess = monitored_session._HookedSession(sess, [h]) 183 sess.run(state_ops.assign(global_step, 5)) 184 h.after_create_session(sess, None) 185 mon_sess.run(no_op) 186 self.assertFalse(mon_sess.should_stop()) 187 sess.run(state_ops.assign(global_step, 13)) 188 mon_sess.run(no_op) 189 self.assertFalse(mon_sess.should_stop()) 190 sess.run(state_ops.assign(global_step, 14)) 191 mon_sess.run(no_op) 192 self.assertFalse(mon_sess.should_stop()) 193 sess.run(state_ops.assign(global_step, 15)) 194 mon_sess.run(no_op) 195 self.assertTrue(mon_sess.should_stop()) 196 sess.run(state_ops.assign(global_step, 16)) 197 mon_sess._should_stop = False 198 mon_sess.run(no_op) 199 self.assertTrue(mon_sess.should_stop()) 200 201 def test_stop_based_with_multiple_steps(self): 202 h = basic_session_run_hooks.StopAtStepHook(num_steps=10) 203 204 with ops.Graph().as_default(): 205 global_step = variables.get_or_create_global_step() 206 no_op = control_flow_ops.no_op() 207 h.begin() 208 with session_lib.Session() as sess: 209 mon_sess = monitored_session._HookedSession(sess, [h]) 210 sess.run(state_ops.assign(global_step, 5)) 211 h.after_create_session(sess, None) 212 mon_sess.run(no_op) 213 
self.assertFalse(mon_sess.should_stop()) 214 sess.run(state_ops.assign(global_step, 15)) 215 mon_sess.run(no_op) 216 self.assertTrue(mon_sess.should_stop()) 217 218 219class LoggingTensorHookTest(test.TestCase): 220 221 def setUp(self): 222 # Mock out logging calls so we can verify whether correct tensors are being 223 # monitored. 224 self._actual_log = tf_logging.info 225 self.logged_message = None 226 227 def mock_log(*args, **kwargs): 228 self.logged_message = args 229 self._actual_log(*args, **kwargs) 230 231 tf_logging.info = mock_log 232 233 def tearDown(self): 234 tf_logging.info = self._actual_log 235 236 def test_illegal_args(self): 237 with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'): 238 basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=0) 239 with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'): 240 basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=-10) 241 with self.assertRaisesRegexp(ValueError, 'xactly one of'): 242 basic_session_run_hooks.LoggingTensorHook( 243 tensors=['t'], every_n_iter=5, every_n_secs=5) 244 with self.assertRaisesRegexp(ValueError, 'xactly one of'): 245 basic_session_run_hooks.LoggingTensorHook(tensors=['t']) 246 247 def test_print_at_end_only(self): 248 with ops.Graph().as_default(), session_lib.Session() as sess: 249 t = constant_op.constant(42.0, name='foo') 250 train_op = constant_op.constant(3) 251 hook = basic_session_run_hooks.LoggingTensorHook( 252 tensors=[t.name], at_end=True) 253 hook.begin() 254 mon_sess = monitored_session._HookedSession(sess, [hook]) 255 self.evaluate(variables_lib.global_variables_initializer()) 256 self.logged_message = '' 257 for _ in range(3): 258 mon_sess.run(train_op) 259 # assertNotRegexpMatches is not supported by python 3.1 and later 260 self.assertEqual(str(self.logged_message).find(t.name), -1) 261 262 hook.end(sess) 263 self.assertRegexpMatches(str(self.logged_message), t.name) 264 265 def 
_validate_print_every_n_steps(self, sess, at_end): 266 t = constant_op.constant(42.0, name='foo') 267 268 train_op = constant_op.constant(3) 269 hook = basic_session_run_hooks.LoggingTensorHook( 270 tensors=[t.name], every_n_iter=10, at_end=at_end) 271 hook.begin() 272 mon_sess = monitored_session._HookedSession(sess, [hook]) 273 self.evaluate(variables_lib.global_variables_initializer()) 274 mon_sess.run(train_op) 275 self.assertRegexpMatches(str(self.logged_message), t.name) 276 for _ in range(3): 277 self.logged_message = '' 278 for _ in range(9): 279 mon_sess.run(train_op) 280 # assertNotRegexpMatches is not supported by python 3.1 and later 281 self.assertEqual(str(self.logged_message).find(t.name), -1) 282 mon_sess.run(train_op) 283 self.assertRegexpMatches(str(self.logged_message), t.name) 284 285 # Add additional run to verify proper reset when called multiple times. 286 self.logged_message = '' 287 mon_sess.run(train_op) 288 # assertNotRegexpMatches is not supported by python 3.1 and later 289 self.assertEqual(str(self.logged_message).find(t.name), -1) 290 291 self.logged_message = '' 292 hook.end(sess) 293 if at_end: 294 self.assertRegexpMatches(str(self.logged_message), t.name) 295 else: 296 # assertNotRegexpMatches is not supported by python 3.1 and later 297 self.assertEqual(str(self.logged_message).find(t.name), -1) 298 299 def test_print_every_n_steps(self): 300 with ops.Graph().as_default(), session_lib.Session() as sess: 301 self._validate_print_every_n_steps(sess, at_end=False) 302 # Verify proper reset. 303 self._validate_print_every_n_steps(sess, at_end=False) 304 305 def test_print_every_n_steps_and_end(self): 306 with ops.Graph().as_default(), session_lib.Session() as sess: 307 self._validate_print_every_n_steps(sess, at_end=True) 308 # Verify proper reset. 309 self._validate_print_every_n_steps(sess, at_end=True) 310 311 def test_print_first_step(self): 312 # if it runs every iteration, first iteration has None duration. 
313 with ops.Graph().as_default(), session_lib.Session() as sess: 314 t = constant_op.constant(42.0, name='foo') 315 train_op = constant_op.constant(3) 316 hook = basic_session_run_hooks.LoggingTensorHook( 317 tensors={'foo': t}, every_n_iter=1) 318 hook.begin() 319 mon_sess = monitored_session._HookedSession(sess, [hook]) 320 self.evaluate(variables_lib.global_variables_initializer()) 321 mon_sess.run(train_op) 322 self.assertRegexpMatches(str(self.logged_message), 'foo') 323 # in first run, elapsed time is None. 324 self.assertEqual(str(self.logged_message).find('sec'), -1) 325 326 def _validate_print_every_n_secs(self, sess, at_end, mock_time): 327 t = constant_op.constant(42.0, name='foo') 328 train_op = constant_op.constant(3) 329 330 hook = basic_session_run_hooks.LoggingTensorHook( 331 tensors=[t.name], every_n_secs=1.0, at_end=at_end) 332 hook.begin() 333 mon_sess = monitored_session._HookedSession(sess, [hook]) 334 self.evaluate(variables_lib.global_variables_initializer()) 335 336 mon_sess.run(train_op) 337 self.assertRegexpMatches(str(self.logged_message), t.name) 338 339 # assertNotRegexpMatches is not supported by python 3.1 and later 340 self.logged_message = '' 341 mon_sess.run(train_op) 342 self.assertEqual(str(self.logged_message).find(t.name), -1) 343 mock_time.return_value += 1.0 344 345 self.logged_message = '' 346 mon_sess.run(train_op) 347 self.assertRegexpMatches(str(self.logged_message), t.name) 348 349 self.logged_message = '' 350 hook.end(sess) 351 if at_end: 352 self.assertRegexpMatches(str(self.logged_message), t.name) 353 else: 354 # assertNotRegexpMatches is not supported by python 3.1 and later 355 self.assertEqual(str(self.logged_message).find(t.name), -1) 356 357 @test.mock.patch.object(time, 'time') 358 def test_print_every_n_secs(self, mock_time): 359 with ops.Graph().as_default(), session_lib.Session() as sess: 360 mock_time.return_value = MOCK_START_TIME 361 self._validate_print_every_n_secs(sess, at_end=False, 
mock_time=mock_time) 362 # Verify proper reset. 363 self._validate_print_every_n_secs(sess, at_end=False, mock_time=mock_time) 364 365 @test.mock.patch.object(time, 'time') 366 def test_print_every_n_secs_and_end(self, mock_time): 367 with ops.Graph().as_default(), session_lib.Session() as sess: 368 mock_time.return_value = MOCK_START_TIME 369 self._validate_print_every_n_secs(sess, at_end=True, mock_time=mock_time) 370 # Verify proper reset. 371 self._validate_print_every_n_secs(sess, at_end=True, mock_time=mock_time) 372 373 def test_print_formatter(self): 374 with ops.Graph().as_default(), session_lib.Session() as sess: 375 t = constant_op.constant(42.0, name='foo') 376 train_op = constant_op.constant(3) 377 hook = basic_session_run_hooks.LoggingTensorHook( 378 tensors=[t.name], every_n_iter=10, 379 formatter=lambda items: 'qqq=%s' % items[t.name]) 380 hook.begin() 381 mon_sess = monitored_session._HookedSession(sess, [hook]) 382 self.evaluate(variables_lib.global_variables_initializer()) 383 mon_sess.run(train_op) 384 self.assertEqual(self.logged_message[0], 'qqq=42.0') 385 386 387class CheckpointSaverHookTest(test.TestCase): 388 389 def setUp(self): 390 self.model_dir = tempfile.mkdtemp() 391 self.graph = ops.Graph() 392 with self.graph.as_default(): 393 self.scaffold = monitored_session.Scaffold() 394 self.global_step = variables.get_or_create_global_step() 395 self.train_op = training_util._increment_global_step(1) 396 397 def tearDown(self): 398 shutil.rmtree(self.model_dir, ignore_errors=True) 399 400 def test_saves_when_saver_and_scaffold_both_missing(self): 401 with self.graph.as_default(): 402 hook = basic_session_run_hooks.CheckpointSaverHook( 403 self.model_dir, save_steps=1) 404 hook.begin() 405 self.scaffold.finalize() 406 with session_lib.Session() as sess: 407 sess.run(self.scaffold.init_op) 408 mon_sess = monitored_session._HookedSession(sess, [hook]) 409 mon_sess.run(self.train_op) 410 self.assertEqual(1, 411 
checkpoint_utils.load_variable(self.model_dir, 412 self.global_step.name)) 413 414 def test_raise_when_saver_and_scaffold_both_present(self): 415 with self.assertRaises(ValueError): 416 basic_session_run_hooks.CheckpointSaverHook( 417 self.model_dir, saver=self.scaffold.saver, scaffold=self.scaffold) 418 419 @test_util.run_deprecated_v1 420 def test_raise_in_both_secs_and_steps(self): 421 with self.assertRaises(ValueError): 422 basic_session_run_hooks.CheckpointSaverHook( 423 self.model_dir, save_secs=10, save_steps=20) 424 425 @test_util.run_deprecated_v1 426 def test_raise_in_none_secs_and_steps(self): 427 with self.assertRaises(ValueError): 428 basic_session_run_hooks.CheckpointSaverHook(self.model_dir) 429 430 def test_save_secs_saves_in_first_step(self): 431 with self.graph.as_default(): 432 hook = basic_session_run_hooks.CheckpointSaverHook( 433 self.model_dir, save_secs=2, scaffold=self.scaffold) 434 hook.begin() 435 self.scaffold.finalize() 436 with session_lib.Session() as sess: 437 sess.run(self.scaffold.init_op) 438 mon_sess = monitored_session._HookedSession(sess, [hook]) 439 mon_sess.run(self.train_op) 440 self.assertEqual(1, 441 checkpoint_utils.load_variable(self.model_dir, 442 self.global_step.name)) 443 444 def test_save_secs_calls_listeners_at_begin_and_end(self): 445 with self.graph.as_default(): 446 listener = MockCheckpointSaverListener() 447 hook = basic_session_run_hooks.CheckpointSaverHook( 448 self.model_dir, 449 save_secs=2, 450 scaffold=self.scaffold, 451 listeners=[listener]) 452 hook.begin() 453 self.scaffold.finalize() 454 with session_lib.Session() as sess: 455 sess.run(self.scaffold.init_op) 456 mon_sess = monitored_session._HookedSession(sess, [hook]) 457 mon_sess.run(self.train_op) # hook runs here 458 mon_sess.run(self.train_op) # hook won't run here, so it does at end 459 hook.end(sess) # hook runs here 460 self.assertEqual({ 461 'begin': 1, 462 'before_save': 2, 463 'after_save': 2, 464 'end': 1 465 }, listener.get_counts()) 466 
467 def test_listener_with_monitored_session(self): 468 with ops.Graph().as_default(): 469 scaffold = monitored_session.Scaffold() 470 global_step = variables.get_or_create_global_step() 471 train_op = training_util._increment_global_step(1) 472 listener = MockCheckpointSaverListener() 473 hook = basic_session_run_hooks.CheckpointSaverHook( 474 self.model_dir, 475 save_steps=1, 476 scaffold=scaffold, 477 listeners=[listener]) 478 with monitored_session.SingularMonitoredSession( 479 hooks=[hook], 480 scaffold=scaffold, 481 checkpoint_dir=self.model_dir) as sess: 482 sess.run(train_op) 483 sess.run(train_op) 484 global_step_val = sess.raw_session().run(global_step) 485 listener_counts = listener.get_counts() 486 self.assertEqual(2, global_step_val) 487 self.assertEqual({ 488 'begin': 1, 489 'before_save': 3, 490 'after_save': 3, 491 'end': 1 492 }, listener_counts) 493 494 def test_listener_stops_training_in_after_save(self): 495 with ops.Graph().as_default(): 496 scaffold = monitored_session.Scaffold() 497 variables.get_or_create_global_step() 498 train_op = training_util._increment_global_step(1) 499 listener = MockCheckpointSaverListener() 500 hook = basic_session_run_hooks.CheckpointSaverHook( 501 self.model_dir, save_steps=1, scaffold=scaffold, listeners=[listener]) 502 with monitored_session.SingularMonitoredSession( 503 hooks=[hook], scaffold=scaffold, 504 checkpoint_dir=self.model_dir) as sess: 505 sess.run(train_op) 506 self.assertFalse(sess.should_stop()) 507 sess.run(train_op) 508 self.assertFalse(sess.should_stop()) 509 listener.ask_for_stop = True 510 sess.run(train_op) 511 self.assertTrue(sess.should_stop()) 512 513 def test_listener_with_default_saver(self): 514 with ops.Graph().as_default(): 515 global_step = variables.get_or_create_global_step() 516 train_op = training_util._increment_global_step(1) 517 listener = MockCheckpointSaverListener() 518 hook = basic_session_run_hooks.CheckpointSaverHook( 519 self.model_dir, 520 save_steps=1, 521 
listeners=[listener]) 522 with monitored_session.SingularMonitoredSession( 523 hooks=[hook], 524 checkpoint_dir=self.model_dir) as sess: 525 sess.run(train_op) 526 sess.run(train_op) 527 global_step_val = sess.raw_session().run(global_step) 528 listener_counts = listener.get_counts() 529 self.assertEqual(2, global_step_val) 530 self.assertEqual({ 531 'begin': 1, 532 'before_save': 3, 533 'after_save': 3, 534 'end': 1 535 }, listener_counts) 536 537 with ops.Graph().as_default(): 538 global_step = variables.get_or_create_global_step() 539 with monitored_session.SingularMonitoredSession( 540 checkpoint_dir=self.model_dir) as sess2: 541 global_step_saved_val = sess2.run(global_step) 542 self.assertEqual(2, global_step_saved_val) 543 544 def test_two_listeners_with_default_saver(self): 545 with ops.Graph().as_default(): 546 global_step = variables.get_or_create_global_step() 547 train_op = training_util._increment_global_step(1) 548 listener1 = MockCheckpointSaverListener() 549 listener2 = MockCheckpointSaverListener() 550 hook = basic_session_run_hooks.CheckpointSaverHook( 551 self.model_dir, 552 save_steps=1, 553 listeners=[listener1, listener2]) 554 with monitored_session.SingularMonitoredSession( 555 hooks=[hook], 556 checkpoint_dir=self.model_dir) as sess: 557 sess.run(train_op) 558 sess.run(train_op) 559 global_step_val = sess.raw_session().run(global_step) 560 listener1_counts = listener1.get_counts() 561 listener2_counts = listener2.get_counts() 562 self.assertEqual(2, global_step_val) 563 self.assertEqual({ 564 'begin': 1, 565 'before_save': 3, 566 'after_save': 3, 567 'end': 1 568 }, listener1_counts) 569 self.assertEqual(listener1_counts, listener2_counts) 570 571 with ops.Graph().as_default(): 572 global_step = variables.get_or_create_global_step() 573 with monitored_session.SingularMonitoredSession( 574 checkpoint_dir=self.model_dir) as sess2: 575 global_step_saved_val = sess2.run(global_step) 576 self.assertEqual(2, global_step_saved_val) 577 578 
@test.mock.patch.object(time, 'time') 579 def test_save_secs_saves_periodically(self, mock_time): 580 with self.graph.as_default(): 581 mock_time.return_value = MOCK_START_TIME 582 hook = basic_session_run_hooks.CheckpointSaverHook( 583 self.model_dir, save_secs=2, scaffold=self.scaffold) 584 hook.begin() 585 self.scaffold.finalize() 586 587 with session_lib.Session() as sess: 588 sess.run(self.scaffold.init_op) 589 mon_sess = monitored_session._HookedSession(sess, [hook]) 590 591 mock_time.return_value = MOCK_START_TIME 592 mon_sess.run(self.train_op) # Saved. 593 594 mock_time.return_value = MOCK_START_TIME + 0.5 595 mon_sess.run(self.train_op) # Not saved. 596 597 self.assertEqual(1, 598 checkpoint_utils.load_variable(self.model_dir, 599 self.global_step.name)) 600 601 # Simulate 2.5 seconds of sleep. 602 mock_time.return_value = MOCK_START_TIME + 2.5 603 mon_sess.run(self.train_op) # Saved. 604 605 mock_time.return_value = MOCK_START_TIME + 2.6 606 mon_sess.run(self.train_op) # Not saved. 607 608 mock_time.return_value = MOCK_START_TIME + 2.7 609 mon_sess.run(self.train_op) # Not saved. 610 611 self.assertEqual(3, 612 checkpoint_utils.load_variable(self.model_dir, 613 self.global_step.name)) 614 615 # Simulate 7.5 more seconds of sleep (10 seconds from start. 616 mock_time.return_value = MOCK_START_TIME + 10 617 mon_sess.run(self.train_op) # Saved. 
618 self.assertEqual(6, 619 checkpoint_utils.load_variable(self.model_dir, 620 self.global_step.name)) 621 622 @test.mock.patch.object(time, 'time') 623 def test_save_secs_calls_listeners_periodically(self, mock_time): 624 with self.graph.as_default(): 625 mock_time.return_value = MOCK_START_TIME 626 listener = MockCheckpointSaverListener() 627 hook = basic_session_run_hooks.CheckpointSaverHook( 628 self.model_dir, 629 save_secs=2, 630 scaffold=self.scaffold, 631 listeners=[listener]) 632 hook.begin() 633 self.scaffold.finalize() 634 with session_lib.Session() as sess: 635 sess.run(self.scaffold.init_op) 636 mon_sess = monitored_session._HookedSession(sess, [hook]) 637 638 mock_time.return_value = MOCK_START_TIME + 0.5 639 mon_sess.run(self.train_op) # hook runs here 640 641 mock_time.return_value = MOCK_START_TIME + 0.5 642 mon_sess.run(self.train_op) 643 644 mock_time.return_value = MOCK_START_TIME + 3.0 645 mon_sess.run(self.train_op) # hook runs here 646 647 mock_time.return_value = MOCK_START_TIME + 3.5 648 mon_sess.run(self.train_op) 649 650 mock_time.return_value = MOCK_START_TIME + 4.0 651 mon_sess.run(self.train_op) 652 653 mock_time.return_value = MOCK_START_TIME + 6.5 654 mon_sess.run(self.train_op) # hook runs here 655 656 mock_time.return_value = MOCK_START_TIME + 7.0 657 mon_sess.run(self.train_op) # hook won't run here, so it does at end 658 659 mock_time.return_value = MOCK_START_TIME + 7.5 660 hook.end(sess) # hook runs here 661 self.assertEqual({ 662 'begin': 1, 663 'before_save': 4, 664 'after_save': 4, 665 'end': 1 666 }, listener.get_counts()) 667 668 def test_save_steps_saves_in_first_step(self): 669 with self.graph.as_default(): 670 hook = basic_session_run_hooks.CheckpointSaverHook( 671 self.model_dir, save_steps=2, scaffold=self.scaffold) 672 hook.begin() 673 self.scaffold.finalize() 674 with session_lib.Session() as sess: 675 sess.run(self.scaffold.init_op) 676 mon_sess = monitored_session._HookedSession(sess, [hook]) 677 
mon_sess.run(self.train_op) 678 self.assertEqual(1, 679 checkpoint_utils.load_variable(self.model_dir, 680 self.global_step.name)) 681 682 def test_save_steps_saves_periodically(self): 683 with self.graph.as_default(): 684 hook = basic_session_run_hooks.CheckpointSaverHook( 685 self.model_dir, save_steps=2, scaffold=self.scaffold) 686 hook.begin() 687 self.scaffold.finalize() 688 with session_lib.Session() as sess: 689 sess.run(self.scaffold.init_op) 690 mon_sess = monitored_session._HookedSession(sess, [hook]) 691 mon_sess.run(self.train_op) 692 mon_sess.run(self.train_op) 693 # Not saved 694 self.assertEqual(1, 695 checkpoint_utils.load_variable(self.model_dir, 696 self.global_step.name)) 697 mon_sess.run(self.train_op) 698 # saved 699 self.assertEqual(3, 700 checkpoint_utils.load_variable(self.model_dir, 701 self.global_step.name)) 702 mon_sess.run(self.train_op) 703 # Not saved 704 self.assertEqual(3, 705 checkpoint_utils.load_variable(self.model_dir, 706 self.global_step.name)) 707 mon_sess.run(self.train_op) 708 # saved 709 self.assertEqual(5, 710 checkpoint_utils.load_variable(self.model_dir, 711 self.global_step.name)) 712 713 def test_save_saves_at_end(self): 714 with self.graph.as_default(): 715 hook = basic_session_run_hooks.CheckpointSaverHook( 716 self.model_dir, save_secs=2, scaffold=self.scaffold) 717 hook.begin() 718 self.scaffold.finalize() 719 with session_lib.Session() as sess: 720 sess.run(self.scaffold.init_op) 721 mon_sess = monitored_session._HookedSession(sess, [hook]) 722 mon_sess.run(self.train_op) 723 mon_sess.run(self.train_op) 724 hook.end(sess) 725 self.assertEqual(2, 726 checkpoint_utils.load_variable(self.model_dir, 727 self.global_step.name)) 728 729 def test_summary_writer_defs(self): 730 fake_summary_writer.FakeSummaryWriter.install() 731 writer_cache.FileWriterCache.clear() 732 summary_writer = writer_cache.FileWriterCache.get(self.model_dir) 733 734 with self.graph.as_default(): 735 hook = 
basic_session_run_hooks.CheckpointSaverHook( 736 self.model_dir, save_steps=2, scaffold=self.scaffold) 737 hook.begin() 738 self.scaffold.finalize() 739 with session_lib.Session() as sess: 740 sess.run(self.scaffold.init_op) 741 mon_sess = monitored_session._HookedSession(sess, [hook]) 742 hook.after_create_session(sess, None) 743 mon_sess.run(self.train_op) 744 summary_writer.assert_summaries( 745 test_case=self, 746 expected_logdir=self.model_dir, 747 expected_added_meta_graphs=[ 748 meta_graph.create_meta_graph_def( 749 graph_def=self.graph.as_graph_def(add_shapes=True), 750 saver_def=self.scaffold.saver.saver_def) 751 ]) 752 753 fake_summary_writer.FakeSummaryWriter.uninstall() 754 755 def test_save_checkpoint_before_first_train_step(self): 756 with self.graph.as_default(): 757 hook = basic_session_run_hooks.CheckpointSaverHook( 758 self.model_dir, save_steps=2, scaffold=self.scaffold) 759 hook.begin() 760 self.scaffold.finalize() 761 with session_lib.Session() as sess: 762 mon_sess = monitored_session._HookedSession(sess, [hook]) 763 sess.run(self.scaffold.init_op) 764 hook.after_create_session(sess, None) 765 # Verifies that checkpoint is saved at step 0. 766 self.assertEqual(0, 767 checkpoint_utils.load_variable(self.model_dir, 768 self.global_step.name)) 769 # Verifies that no checkpoint is saved after one training step. 770 mon_sess.run(self.train_op) 771 self.assertEqual(0, 772 checkpoint_utils.load_variable(self.model_dir, 773 self.global_step.name)) 774 # Verifies that checkpoint is saved after save_steps. 
775 mon_sess.run(self.train_op) 776 self.assertEqual(2, 777 checkpoint_utils.load_variable(self.model_dir, 778 self.global_step.name)) 779 780 781class CheckpointSaverHookMultiStepTest(test.TestCase): 782 783 def setUp(self): 784 self.model_dir = tempfile.mkdtemp() 785 self.graph = ops.Graph() 786 self.steps_per_run = 5 787 with self.graph.as_default(): 788 self.scaffold = monitored_session.Scaffold() 789 self.global_step = variables.get_or_create_global_step() 790 self.train_op = training_util._increment_global_step(self.steps_per_run) 791 792 def tearDown(self): 793 shutil.rmtree(self.model_dir, ignore_errors=True) 794 795 def test_save_steps_saves_in_first_step(self): 796 with self.graph.as_default(): 797 hook = basic_session_run_hooks.CheckpointSaverHook( 798 self.model_dir, 799 save_steps=2*self.steps_per_run, 800 scaffold=self.scaffold) 801 hook._set_steps_per_run(self.steps_per_run) 802 hook.begin() 803 self.scaffold.finalize() 804 with session_lib.Session() as sess: 805 sess.run(self.scaffold.init_op) 806 mon_sess = monitored_session._HookedSession(sess, [hook]) 807 mon_sess.run(self.train_op) 808 self.assertEqual(5, 809 checkpoint_utils.load_variable(self.model_dir, 810 self.global_step.name)) 811 812 def test_save_steps_saves_periodically(self): 813 with self.graph.as_default(): 814 hook = basic_session_run_hooks.CheckpointSaverHook( 815 self.model_dir, 816 save_steps=2*self.steps_per_run, 817 scaffold=self.scaffold) 818 hook._set_steps_per_run(self.steps_per_run) 819 hook.begin() 820 self.scaffold.finalize() 821 with session_lib.Session() as sess: 822 sess.run(self.scaffold.init_op) 823 mon_sess = monitored_session._HookedSession(sess, [hook]) 824 mon_sess.run(self.train_op) 825 # Saved (step=5) 826 self.assertEqual(5, 827 checkpoint_utils.load_variable(self.model_dir, 828 self.global_step.name)) 829 830 mon_sess.run(self.train_op) 831 # Not saved (step=10) 832 self.assertEqual(5, 833 checkpoint_utils.load_variable(self.model_dir, 834 
self.global_step.name)) 835 836 mon_sess.run(self.train_op) 837 # Saved (step=15) 838 self.assertEqual(15, 839 checkpoint_utils.load_variable(self.model_dir, 840 self.global_step.name)) 841 842 mon_sess.run(self.train_op) 843 # Not saved (step=20) 844 self.assertEqual(15, 845 checkpoint_utils.load_variable(self.model_dir, 846 self.global_step.name)) 847 848 mon_sess.run(self.train_op) 849 # Saved (step=25) 850 self.assertEqual(25, 851 checkpoint_utils.load_variable(self.model_dir, 852 self.global_step.name)) 853 854 def test_save_steps_saves_at_end(self): 855 with self.graph.as_default(): 856 hook = basic_session_run_hooks.CheckpointSaverHook( 857 self.model_dir, 858 save_steps=2*self.steps_per_run, 859 scaffold=self.scaffold) 860 hook._set_steps_per_run(self.steps_per_run) 861 hook.begin() 862 self.scaffold.finalize() 863 with session_lib.Session() as sess: 864 sess.run(self.scaffold.init_op) 865 mon_sess = monitored_session._HookedSession(sess, [hook]) 866 mon_sess.run(self.train_op) 867 mon_sess.run(self.train_op) 868 hook.end(sess) 869 self.assertEqual(10, 870 checkpoint_utils.load_variable(self.model_dir, 871 self.global_step.name)) 872 873 874class ResourceCheckpointSaverHookTest(test.TestCase): 875 876 def setUp(self): 877 self.model_dir = tempfile.mkdtemp() 878 self.graph = ops.Graph() 879 with self.graph.as_default(): 880 self.scaffold = monitored_session.Scaffold() 881 with variable_scope.variable_scope('foo', use_resource=True): 882 self.global_step = training_util.get_or_create_global_step() 883 self.train_op = training_util._increment_global_step(1) 884 885 def test_save_steps_saves_periodically(self): 886 with self.graph.as_default(): 887 hook = basic_session_run_hooks.CheckpointSaverHook( 888 self.model_dir, save_steps=2, scaffold=self.scaffold) 889 hook.begin() 890 self.scaffold.finalize() 891 with session_lib.Session() as sess: 892 sess.run(self.scaffold.init_op) 893 mon_sess = monitored_session._HookedSession(sess, [hook]) 894 
mon_sess.run(self.train_op) 895 mon_sess.run(self.train_op) 896 # Not saved 897 self.assertEqual(1, 898 checkpoint_utils.load_variable(self.model_dir, 899 self.global_step.name)) 900 mon_sess.run(self.train_op) 901 # saved 902 self.assertEqual(3, 903 checkpoint_utils.load_variable(self.model_dir, 904 self.global_step.name)) 905 mon_sess.run(self.train_op) 906 # Not saved 907 self.assertEqual(3, 908 checkpoint_utils.load_variable(self.model_dir, 909 self.global_step.name)) 910 mon_sess.run(self.train_op) 911 # saved 912 self.assertEqual(5, 913 checkpoint_utils.load_variable(self.model_dir, 914 self.global_step.name)) 915 916 917class StepCounterHookTest(test.TestCase): 918 919 def setUp(self): 920 self.log_dir = tempfile.mkdtemp() 921 922 def tearDown(self): 923 shutil.rmtree(self.log_dir, ignore_errors=True) 924 925 @test.mock.patch.object(time, 'time') 926 def test_step_counter_every_n_steps(self, mock_time): 927 mock_time.return_value = MOCK_START_TIME 928 with ops.Graph().as_default() as g, session_lib.Session() as sess: 929 variables.get_or_create_global_step() 930 train_op = training_util._increment_global_step(1) 931 summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g) 932 hook = basic_session_run_hooks.StepCounterHook( 933 summary_writer=summary_writer, every_n_steps=10) 934 hook.begin() 935 self.evaluate(variables_lib.global_variables_initializer()) 936 mon_sess = monitored_session._HookedSession(sess, [hook]) 937 with test.mock.patch.object(tf_logging, 'warning') as mock_log: 938 for _ in range(30): 939 mock_time.return_value += 0.01 940 mon_sess.run(train_op) 941 # logging.warning should not be called. 
942 self.assertIsNone(mock_log.call_args) 943 hook.end(sess) 944 summary_writer.assert_summaries( 945 test_case=self, 946 expected_logdir=self.log_dir, 947 expected_graph=g, 948 expected_summaries={}) 949 self.assertItemsEqual([11, 21], summary_writer.summaries.keys()) 950 for step in [11, 21]: 951 summary_value = summary_writer.summaries[step][0].value[0] 952 self.assertEqual('global_step/sec', summary_value.tag) 953 self.assertGreater(summary_value.simple_value, 0) 954 955 @test.mock.patch.object(time, 'time') 956 def test_step_counter_every_n_secs(self, mock_time): 957 mock_time.return_value = MOCK_START_TIME 958 with ops.Graph().as_default() as g, session_lib.Session() as sess: 959 variables.get_or_create_global_step() 960 train_op = training_util._increment_global_step(1) 961 summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g) 962 hook = basic_session_run_hooks.StepCounterHook( 963 summary_writer=summary_writer, every_n_steps=None, every_n_secs=0.1) 964 965 hook.begin() 966 self.evaluate(variables_lib.global_variables_initializer()) 967 mon_sess = monitored_session._HookedSession(sess, [hook]) 968 mon_sess.run(train_op) 969 mock_time.return_value += 0.2 970 mon_sess.run(train_op) 971 mock_time.return_value += 0.2 972 mon_sess.run(train_op) 973 hook.end(sess) 974 975 summary_writer.assert_summaries( 976 test_case=self, 977 expected_logdir=self.log_dir, 978 expected_graph=g, 979 expected_summaries={}) 980 self.assertTrue(summary_writer.summaries, 'No summaries were created.') 981 self.assertItemsEqual([2, 3], summary_writer.summaries.keys()) 982 for summary in summary_writer.summaries.values(): 983 summary_value = summary[0].value[0] 984 self.assertEqual('global_step/sec', summary_value.tag) 985 self.assertGreater(summary_value.simple_value, 0) 986 987 def test_global_step_name(self): 988 with ops.Graph().as_default() as g, session_lib.Session() as sess: 989 with variable_scope.variable_scope('bar'): 990 variable_scope.get_variable( 991 
'foo', 992 initializer=0, 993 trainable=False, 994 collections=[ 995 ops.GraphKeys.GLOBAL_STEP, ops.GraphKeys.GLOBAL_VARIABLES 996 ]) 997 train_op = training_util._increment_global_step(1) 998 summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g) 999 hook = basic_session_run_hooks.StepCounterHook( 1000 summary_writer=summary_writer, every_n_steps=1, every_n_secs=None) 1001 1002 hook.begin() 1003 self.evaluate(variables_lib.global_variables_initializer()) 1004 mon_sess = monitored_session._HookedSession(sess, [hook]) 1005 mon_sess.run(train_op) 1006 mon_sess.run(train_op) 1007 hook.end(sess) 1008 1009 summary_writer.assert_summaries( 1010 test_case=self, 1011 expected_logdir=self.log_dir, 1012 expected_graph=g, 1013 expected_summaries={}) 1014 self.assertTrue(summary_writer.summaries, 'No summaries were created.') 1015 self.assertItemsEqual([2], summary_writer.summaries.keys()) 1016 summary_value = summary_writer.summaries[2][0].value[0] 1017 self.assertEqual('bar/foo/sec', summary_value.tag) 1018 1019 def test_log_warning_if_global_step_not_increased(self): 1020 with ops.Graph().as_default(), session_lib.Session() as sess: 1021 variables.get_or_create_global_step() 1022 train_op = training_util._increment_global_step(0) # keep same. 1023 self.evaluate(variables_lib.global_variables_initializer()) 1024 hook = basic_session_run_hooks.StepCounterHook( 1025 every_n_steps=1, every_n_secs=None) 1026 hook.begin() 1027 mon_sess = monitored_session._HookedSession(sess, [hook]) 1028 mon_sess.run(train_op) # Run one step to record global step. 
1029 with test.mock.patch.object(tf_logging, 'warning') as mock_log: 1030 for _ in range(30): 1031 mon_sess.run(train_op) 1032 self.assertRegexpMatches( 1033 str(mock_log.call_args), 1034 'global step.*has not been increased') 1035 hook.end(sess) 1036 1037 def _setup_steps_per_run_test(self, 1038 every_n_steps, 1039 steps_per_run, 1040 graph, 1041 sess): 1042 variables.get_or_create_global_step() 1043 self.train_op = training_util._increment_global_step(steps_per_run) 1044 self.summary_writer = fake_summary_writer.FakeSummaryWriter( 1045 self.log_dir, graph) 1046 self.hook = basic_session_run_hooks.StepCounterHook( 1047 summary_writer=self.summary_writer, every_n_steps=every_n_steps) 1048 self.hook._set_steps_per_run(steps_per_run) 1049 self.hook.begin() 1050 self.evaluate(variables_lib.global_variables_initializer()) 1051 self.mon_sess = monitored_session._HookedSession(sess, [self.hook]) 1052 1053 @test.mock.patch.object(time, 'time') 1054 def test_steps_per_run_less_than_every_n_steps(self, mock_time): 1055 mock_time.return_value = MOCK_START_TIME 1056 with ops.Graph().as_default() as g, session_lib.Session() as sess: 1057 self._setup_steps_per_run_test(10, 5, g, sess) 1058 1059 # Logs at 15, 25 1060 for _ in range(5): 1061 mock_time.return_value += 0.01 1062 self.mon_sess.run(self.train_op) 1063 1064 self.hook.end(sess) 1065 self.summary_writer.assert_summaries( 1066 test_case=self, 1067 expected_logdir=self.log_dir, 1068 expected_graph=g, 1069 expected_summaries={}) 1070 self.assertItemsEqual([15, 25], self.summary_writer.summaries.keys()) 1071 for step in [15, 25]: 1072 summary_value = self.summary_writer.summaries[step][0].value[0] 1073 self.assertEqual('global_step/sec', summary_value.tag) 1074 self.assertGreater(summary_value.simple_value, 0) 1075 1076 @test.mock.patch.object(time, 'time') 1077 def test_steps_per_run_equal_every_n_steps(self, mock_time): 1078 mock_time.return_value = MOCK_START_TIME 1079 with ops.Graph().as_default() as g, 
session_lib.Session() as sess: 1080 self._setup_steps_per_run_test(5, 5, g, sess) 1081 1082 # Logs at 10, 15, 20, 25 1083 for _ in range(5): 1084 mock_time.return_value += 0.01 1085 self.mon_sess.run(self.train_op) 1086 1087 self.hook.end(sess) 1088 self.summary_writer.assert_summaries( 1089 test_case=self, 1090 expected_logdir=self.log_dir, 1091 expected_graph=g, 1092 expected_summaries={}) 1093 self.assertItemsEqual([10, 15, 20, 25], 1094 self.summary_writer.summaries.keys()) 1095 for step in [10, 15, 20, 25]: 1096 summary_value = self.summary_writer.summaries[step][0].value[0] 1097 self.assertEqual('global_step/sec', summary_value.tag) 1098 self.assertGreater(summary_value.simple_value, 0) 1099 1100 @test.mock.patch.object(time, 'time') 1101 def test_steps_per_run_greater_than_every_n_steps(self, mock_time): 1102 mock_time.return_value = MOCK_START_TIME 1103 with ops.Graph().as_default() as g, session_lib.Session() as sess: 1104 self._setup_steps_per_run_test(5, 10, g, sess) 1105 1106 # Logs at 20, 30, 40, 50 1107 for _ in range(5): 1108 mock_time.return_value += 0.01 1109 self.mon_sess.run(self.train_op) 1110 1111 self.hook.end(sess) 1112 self.summary_writer.assert_summaries( 1113 test_case=self, 1114 expected_logdir=self.log_dir, 1115 expected_graph=g, 1116 expected_summaries={}) 1117 self.assertItemsEqual([20, 30, 40, 50], 1118 self.summary_writer.summaries.keys()) 1119 for step in [20, 30, 40, 50]: 1120 summary_value = self.summary_writer.summaries[step][0].value[0] 1121 self.assertEqual('global_step/sec', summary_value.tag) 1122 self.assertGreater(summary_value.simple_value, 0) 1123 1124 1125@test_util.run_deprecated_v1 1126class SummarySaverHookTest(test.TestCase): 1127 1128 def setUp(self): 1129 test.TestCase.setUp(self) 1130 1131 self.log_dir = 'log/dir' 1132 self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir) 1133 1134 var = variables_lib.Variable(0.0) 1135 tensor = state_ops.assign_add(var, 1.0) 1136 tensor2 = tensor * 2 1137 
self.summary_op = summary_lib.scalar('my_summary', tensor) 1138 self.summary_op2 = summary_lib.scalar('my_summary2', tensor2) 1139 1140 variables.get_or_create_global_step() 1141 self.train_op = training_util._increment_global_step(1) 1142 1143 def test_raise_when_scaffold_and_summary_op_both_missing(self): 1144 with self.assertRaises(ValueError): 1145 basic_session_run_hooks.SummarySaverHook() 1146 1147 def test_raise_when_scaffold_and_summary_op_both_present(self): 1148 with self.assertRaises(ValueError): 1149 basic_session_run_hooks.SummarySaverHook( 1150 scaffold=monitored_session.Scaffold(), summary_op=self.summary_op) 1151 1152 def test_raise_in_both_secs_and_steps(self): 1153 with self.assertRaises(ValueError): 1154 basic_session_run_hooks.SummarySaverHook( 1155 save_secs=10, save_steps=20, summary_writer=self.summary_writer) 1156 1157 def test_raise_in_none_secs_and_steps(self): 1158 with self.assertRaises(ValueError): 1159 basic_session_run_hooks.SummarySaverHook( 1160 save_secs=None, save_steps=None, summary_writer=self.summary_writer) 1161 1162 def test_save_steps(self): 1163 hook = basic_session_run_hooks.SummarySaverHook( 1164 save_steps=8, 1165 summary_writer=self.summary_writer, 1166 summary_op=self.summary_op) 1167 1168 with self.cached_session() as sess: 1169 hook.begin() 1170 self.evaluate(variables_lib.global_variables_initializer()) 1171 mon_sess = monitored_session._HookedSession(sess, [hook]) 1172 for _ in range(30): 1173 mon_sess.run(self.train_op) 1174 hook.end(sess) 1175 1176 self.summary_writer.assert_summaries( 1177 test_case=self, 1178 expected_logdir=self.log_dir, 1179 expected_summaries={ 1180 1: { 1181 'my_summary': 1.0 1182 }, 1183 9: { 1184 'my_summary': 2.0 1185 }, 1186 17: { 1187 'my_summary': 3.0 1188 }, 1189 25: { 1190 'my_summary': 4.0 1191 }, 1192 }) 1193 1194 def test_multiple_summaries(self): 1195 hook = basic_session_run_hooks.SummarySaverHook( 1196 save_steps=8, 1197 summary_writer=self.summary_writer, 1198 
summary_op=[self.summary_op, self.summary_op2]) 1199 1200 with self.cached_session() as sess: 1201 hook.begin() 1202 self.evaluate(variables_lib.global_variables_initializer()) 1203 mon_sess = monitored_session._HookedSession(sess, [hook]) 1204 for _ in range(10): 1205 mon_sess.run(self.train_op) 1206 hook.end(sess) 1207 1208 self.summary_writer.assert_summaries( 1209 test_case=self, 1210 expected_logdir=self.log_dir, 1211 expected_summaries={ 1212 1: { 1213 'my_summary': 1.0, 1214 'my_summary2': 2.0 1215 }, 1216 9: { 1217 'my_summary': 2.0, 1218 'my_summary2': 4.0 1219 }, 1220 }) 1221 1222 @test.mock.patch.object(time, 'time') 1223 def test_save_secs_saving_once_every_step(self, mock_time): 1224 mock_time.return_value = MOCK_START_TIME 1225 hook = basic_session_run_hooks.SummarySaverHook( 1226 save_secs=0.5, 1227 summary_writer=self.summary_writer, 1228 summary_op=self.summary_op) 1229 1230 with self.cached_session() as sess: 1231 hook.begin() 1232 self.evaluate(variables_lib.global_variables_initializer()) 1233 mon_sess = monitored_session._HookedSession(sess, [hook]) 1234 for _ in range(4): 1235 mon_sess.run(self.train_op) 1236 mock_time.return_value += 0.5 1237 hook.end(sess) 1238 1239 self.summary_writer.assert_summaries( 1240 test_case=self, 1241 expected_logdir=self.log_dir, 1242 expected_summaries={ 1243 1: { 1244 'my_summary': 1.0 1245 }, 1246 2: { 1247 'my_summary': 2.0 1248 }, 1249 3: { 1250 'my_summary': 3.0 1251 }, 1252 4: { 1253 'my_summary': 4.0 1254 }, 1255 }) 1256 1257 @test.mock.patch.object(time, 'time') 1258 def test_save_secs_saving_once_every_three_steps(self, mock_time): 1259 mock_time.return_value = 1484695987.209386 1260 hook = basic_session_run_hooks.SummarySaverHook( 1261 save_secs=9., 1262 summary_writer=self.summary_writer, 1263 summary_op=self.summary_op) 1264 1265 with self.cached_session() as sess: 1266 hook.begin() 1267 self.evaluate(variables_lib.global_variables_initializer()) 1268 mon_sess = monitored_session._HookedSession(sess, 
[hook]) 1269 for _ in range(8): 1270 mon_sess.run(self.train_op) 1271 mock_time.return_value += 3.1 1272 hook.end(sess) 1273 1274 # 24.8 seconds passed (3.1*8), it saves every 9 seconds starting from first: 1275 self.summary_writer.assert_summaries( 1276 test_case=self, 1277 expected_logdir=self.log_dir, 1278 expected_summaries={ 1279 1: { 1280 'my_summary': 1.0 1281 }, 1282 4: { 1283 'my_summary': 2.0 1284 }, 1285 7: { 1286 'my_summary': 3.0 1287 }, 1288 }) 1289 1290 1291class GlobalStepWaiterHookTest(test.TestCase): 1292 1293 def test_not_wait_for_step_zero(self): 1294 with ops.Graph().as_default(): 1295 variables.get_or_create_global_step() 1296 hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=0) 1297 hook.begin() 1298 with session_lib.Session() as sess: 1299 # Before run should return without waiting gstep increment. 1300 hook.before_run( 1301 session_run_hook.SessionRunContext( 1302 original_args=None, session=sess)) 1303 1304 @test.mock.patch.object(time, 'sleep') 1305 def test_wait_for_step(self, mock_sleep): 1306 with ops.Graph().as_default(): 1307 gstep = variables.get_or_create_global_step() 1308 hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=1000) 1309 hook.begin() 1310 1311 with session_lib.Session() as sess: 1312 # Mock out calls to time.sleep() to update the global step. 1313 1314 class Context(object): 1315 counter = 0 1316 1317 def mock_sleep_side_effect(seconds): 1318 del seconds # argument is ignored 1319 Context.counter += 1 1320 if Context.counter == 1: 1321 # The first time sleep() is called, we update the global_step from 1322 # 0 to 500. 1323 sess.run(state_ops.assign(gstep, 500)) 1324 elif Context.counter == 2: 1325 # The second time sleep() is called, we update the global_step from 1326 # 500 to 1100. 
1327 sess.run(state_ops.assign(gstep, 1100)) 1328 else: 1329 raise AssertionError( 1330 'Expected before_run() to terminate after the second call to ' 1331 'time.sleep()') 1332 1333 mock_sleep.side_effect = mock_sleep_side_effect 1334 1335 # Run the mocked-out interaction with the hook. 1336 self.evaluate(variables_lib.global_variables_initializer()) 1337 run_context = session_run_hook.SessionRunContext( 1338 original_args=None, session=sess) 1339 hook.before_run(run_context) 1340 self.assertEqual(Context.counter, 2) 1341 1342 1343class FinalOpsHookTest(test.TestCase): 1344 1345 def test_final_ops_is_scalar_tensor(self): 1346 with ops.Graph().as_default(): 1347 expected_value = 4 1348 final_ops = constant_op.constant(expected_value) 1349 1350 hook = basic_session_run_hooks.FinalOpsHook(final_ops) 1351 hook.begin() 1352 1353 with session_lib.Session() as session: 1354 hook.end(session) 1355 self.assertEqual(expected_value, 1356 hook.final_ops_values) 1357 1358 def test_final_ops_is_tensor(self): 1359 with ops.Graph().as_default(): 1360 expected_values = [1, 6, 3, 5, 2, 4] 1361 final_ops = constant_op.constant(expected_values) 1362 1363 hook = basic_session_run_hooks.FinalOpsHook(final_ops) 1364 hook.begin() 1365 1366 with session_lib.Session() as session: 1367 hook.end(session) 1368 self.assertListEqual(expected_values, 1369 hook.final_ops_values.tolist()) 1370 1371 def test_final_ops_triggers_out_of_range_error(self): 1372 with ops.Graph().as_default(): 1373 dataset = dataset_ops.Dataset.range(1) 1374 iterator = dataset_ops.make_one_shot_iterator(dataset) 1375 read_ops = iterator.get_next() 1376 final_ops = read_ops 1377 1378 hook = basic_session_run_hooks.FinalOpsHook(final_ops) 1379 hook.begin() 1380 1381 with session_lib.Session() as session: 1382 session.run(read_ops) 1383 with test.mock.patch.object(tf_logging, 'warning') as mock_log: 1384 with self.assertRaisesRegexp(errors.OutOfRangeError, 1385 'End of sequence'): 1386 hook.end(session) 1387 
self.assertRegexpMatches( 1388 str(mock_log.call_args), 1389 'dependency back to some input source') 1390 1391 def test_final_ops_with_dictionary(self): 1392 with ops.Graph().as_default(): 1393 expected_values = [4, -3] 1394 final_ops = array_ops.placeholder(dtype=dtypes.float32) 1395 final_ops_feed_dict = {final_ops: expected_values} 1396 1397 hook = basic_session_run_hooks.FinalOpsHook( 1398 final_ops, final_ops_feed_dict) 1399 hook.begin() 1400 1401 with session_lib.Session() as session: 1402 hook.end(session) 1403 self.assertListEqual(expected_values, 1404 hook.final_ops_values.tolist()) 1405 1406 1407@test_util.run_deprecated_v1 1408class ResourceSummarySaverHookTest(test.TestCase): 1409 1410 def setUp(self): 1411 test.TestCase.setUp(self) 1412 1413 self.log_dir = 'log/dir' 1414 self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir) 1415 1416 var = variable_scope.get_variable('var', initializer=0.0, use_resource=True) 1417 tensor = state_ops.assign_add(var, 1.0) 1418 self.summary_op = summary_lib.scalar('my_summary', tensor) 1419 1420 with variable_scope.variable_scope('foo', use_resource=True): 1421 variables.create_global_step() 1422 self.train_op = training_util._increment_global_step(1) 1423 1424 def test_save_steps(self): 1425 hook = basic_session_run_hooks.SummarySaverHook( 1426 save_steps=8, 1427 summary_writer=self.summary_writer, 1428 summary_op=self.summary_op) 1429 1430 with self.cached_session() as sess: 1431 hook.begin() 1432 self.evaluate(variables_lib.global_variables_initializer()) 1433 mon_sess = monitored_session._HookedSession(sess, [hook]) 1434 for _ in range(30): 1435 mon_sess.run(self.train_op) 1436 hook.end(sess) 1437 1438 self.summary_writer.assert_summaries( 1439 test_case=self, 1440 expected_logdir=self.log_dir, 1441 expected_summaries={ 1442 1: { 1443 'my_summary': 1.0 1444 }, 1445 9: { 1446 'my_summary': 2.0 1447 }, 1448 17: { 1449 'my_summary': 3.0 1450 }, 1451 25: { 1452 'my_summary': 4.0 1453 }, 1454 }) 1455 
class FeedFnHookTest(test.TestCase):
  """Tests that FeedFnHook supplies feed values to session runs."""

  def test_feeding_placeholder(self):
    """The placeholder is fed 1.0 by the hook, so x + 1 evaluates to 2."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
      x = array_ops.placeholder(dtype=dtypes.float32)
      y = x + 1
      hook = basic_session_run_hooks.FeedFnHook(
          feed_fn=lambda: {x: 1.0})
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      self.assertEqual(mon_sess.run(y), 2)


class ProfilerHookTest(test.TestCase):
  """Tests ProfilerHook's timeline-dump schedule and run-metadata output."""

  def setUp(self):
    super(ProfilerHookTest, self).setUp()
    self.output_dir = tempfile.mkdtemp()
    self.graph = ops.Graph()
    self.filepattern = os.path.join(self.output_dir, 'timeline-*.json')
    with self.graph.as_default():
      self.global_step = variables.get_or_create_global_step()
      self.train_op = state_ops.assign_add(self.global_step, 1)

  def tearDown(self):
    super(ProfilerHookTest, self).tearDown()
    shutil.rmtree(self.output_dir, ignore_errors=True)

  def _count_timeline_files(self):
    """Returns how many files in output_dir match 'timeline-*.json'."""
    return len(gfile.Glob(self.filepattern))

  @test_util.run_deprecated_v1
  def test_raise_in_both_secs_and_steps(self):
    """Passing both save_secs and save_steps raises."""
    with self.assertRaises(ValueError):
      basic_session_run_hooks.ProfilerHook(save_secs=10, save_steps=20)

  @test_util.run_deprecated_v1
  def test_raise_in_none_secs_and_steps(self):
    """Passing neither save_secs nor save_steps raises."""
    with self.assertRaises(ValueError):
      basic_session_run_hooks.ProfilerHook(save_secs=None, save_steps=None)

  def test_save_secs_does_not_save_in_first_step(self):
    """No timeline is written on the first run with save_secs."""
    with self.graph.as_default():
      hook = basic_session_run_hooks.ProfilerHook(
          save_secs=2, output_dir=self.output_dir)
      with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
        sess.run(self.train_op)
        self.assertEqual(0, self._count_timeline_files())

  @test.mock.patch.object(time, 'time')
  def test_save_secs_saves_periodically(self, mock_time):
    """With save_secs=2, a timeline is written once 2 mocked seconds pass."""
    with self.graph.as_default():
      # Pick a fixed start time.
      mock_time.return_value = MOCK_START_TIME
      hook = basic_session_run_hooks.ProfilerHook(
          save_secs=2, output_dir=self.output_dir)
      with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
        sess.run(self.train_op)  # Not saved.
        self.assertEqual(0, self._count_timeline_files())
        # Simulate 2.5 seconds of sleep.
        mock_time.return_value = MOCK_START_TIME + 2.5
        sess.run(self.train_op)  # Saved.
        self.assertEqual(1, self._count_timeline_files())

        # Pretend some small amount of time has passed.
        mock_time.return_value = MOCK_START_TIME + 2.6
        sess.run(self.train_op)  # Not saved.
        # Edge test just before we should save the timeline.
        mock_time.return_value = MOCK_START_TIME + 4.4
        sess.run(self.train_op)  # Not saved.
        self.assertEqual(1, self._count_timeline_files())

        mock_time.return_value = MOCK_START_TIME + 4.5
        sess.run(self.train_op)  # Saved.
        self.assertEqual(2, self._count_timeline_files())

  def test_save_steps_does_not_save_in_first_step(self):
    """No timeline is written on the first run with save_steps."""
    with self.graph.as_default():
      hook = basic_session_run_hooks.ProfilerHook(
          save_steps=1, output_dir=self.output_dir)
      with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
        sess.run(self.train_op)  # Not saved.
        self.assertEqual(0, self._count_timeline_files())

  def test_save_steps_saves_periodically(self):
    """With save_steps=2, a timeline is written on every second run."""
    with self.graph.as_default():
      hook = basic_session_run_hooks.ProfilerHook(
          save_steps=2, output_dir=self.output_dir)
      with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
        self.assertEqual(0, self._count_timeline_files())
        sess.run(self.train_op)  # Not saved.
        self.assertEqual(0, self._count_timeline_files())
        sess.run(self.train_op)  # Saved.
        self.assertEqual(1, self._count_timeline_files())
        sess.run(self.train_op)  # Not saved.
        self.assertEqual(1, self._count_timeline_files())
        sess.run(self.train_op)  # Saved.
        self.assertEqual(2, self._count_timeline_files())
        sess.run(self.train_op)  # Not saved.
        self.assertEqual(2, self._count_timeline_files())

  def test_run_metadata_saves(self):
    """Run metadata goes to the cached summary writer on saving steps."""
    writer_cache.FileWriterCache.clear()
    fake_summary_writer.FakeSummaryWriter.install()
    # install() globally replaces the FileWriter class; undo it even if an
    # assertion below fails, otherwise later install() calls raise and the
    # cache keeps handing out the fake writer.
    try:
      fake_writer = writer_cache.FileWriterCache.get(self.output_dir)
      with self.graph.as_default():
        hook = basic_session_run_hooks.ProfilerHook(
            save_steps=1, output_dir=self.output_dir)
        with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
          sess.run(self.train_op)  # Not saved.
          sess.run(self.train_op)  # Saved.
          self.assertEqual(
              list(fake_writer._added_run_metadata.keys()), ['step_2'])
    finally:
      fake_summary_writer.FakeSummaryWriter.uninstall()
      writer_cache.FileWriterCache.clear()


if __name__ == '__main__':
  test.main()