1# pylint: disable=g-bad-file-header 2# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Tests for basic_session_run_hooks.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import os.path 23import shutil 24import tempfile 25import time 26 27from tensorflow.python.client import session as session_lib 28from tensorflow.python.data.ops import dataset_ops 29from tensorflow.python.framework import constant_op 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import errors 32from tensorflow.python.framework import meta_graph 33from tensorflow.python.framework import ops 34from tensorflow.python.framework import test_util 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import control_flow_ops 37from tensorflow.python.ops import state_ops 38from tensorflow.python.ops import variable_scope 39from tensorflow.python.ops import variables as variables_lib 40import tensorflow.python.ops.nn_grad # pylint: disable=unused-import 41from tensorflow.python.platform import gfile 42from tensorflow.python.platform import test 43from tensorflow.python.platform import tf_logging 44from tensorflow.python.summary import summary as summary_lib 45from tensorflow.python.summary.writer import fake_summary_writer 46from tensorflow.python.summary.writer import writer_cache 47from tensorflow.python.training import basic_session_run_hooks 48from tensorflow.python.training import checkpoint_utils 49from tensorflow.python.training import monitored_session 50from tensorflow.python.training import session_run_hook 51from tensorflow.python.training import training_util 52 53 54# Provide a realistic start time for unit tests where we need to mock out 55# calls to time.time(). 56MOCK_START_TIME = 1484695987.209386 57 58 59class MockCheckpointSaverListener( 60 basic_session_run_hooks.CheckpointSaverListener): 61 62 def __init__(self): 63 self.begin_count = 0 64 self.before_save_count = 0 65 self.after_save_count = 0 66 self.end_count = 0 67 self.ask_for_stop = False 68 69 def begin(self): 70 self.begin_count += 1 71 72 def before_save(self, session, global_step): 73 self.before_save_count += 1 74 75 def after_save(self, session, global_step): 76 self.after_save_count += 1 77 if self.ask_for_stop: 78 return True 79 80 def end(self, session, global_step): 81 self.end_count += 1 82 83 def get_counts(self): 84 return { 85 'begin': self.begin_count, 86 'before_save': self.before_save_count, 87 'after_save': self.after_save_count, 88 'end': self.end_count 89 } 90 91 92class SecondOrStepTimerTest(test.TestCase): 93 94 @test_util.run_deprecated_v1 95 def test_raise_in_both_secs_and_steps(self): 96 with self.assertRaises(ValueError): 97 basic_session_run_hooks.SecondOrStepTimer(every_secs=2.0, every_steps=10) 98 99 @test_util.run_deprecated_v1 100 def test_raise_in_none_secs_and_steps(self): 101 with self.assertRaises(ValueError): 102 basic_session_run_hooks.SecondOrStepTimer() 103 104 @test.mock.patch.object(time, 'time') 105 def test_every_secs(self, mock_time): 106 mock_time.return_value = MOCK_START_TIME 107 timer = basic_session_run_hooks.SecondOrStepTimer(every_secs=1.0) 108 self.assertTrue(timer.should_trigger_for_step(1)) 109 110 timer.update_last_triggered_step(1) 111 self.assertFalse(timer.should_trigger_for_step(1)) 112 self.assertFalse(timer.should_trigger_for_step(2)) 113 114 mock_time.return_value += 1.0 115 self.assertFalse(timer.should_trigger_for_step(1)) 116 self.assertTrue(timer.should_trigger_for_step(2)) 117 118 def test_every_steps(self): 119 timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=3) 120 self.assertTrue(timer.should_trigger_for_step(1)) 121 122 timer.update_last_triggered_step(1) 123 self.assertFalse(timer.should_trigger_for_step(1)) 124 self.assertFalse(timer.should_trigger_for_step(2)) 125 self.assertFalse(timer.should_trigger_for_step(3)) 126 self.assertTrue(timer.should_trigger_for_step(4)) 127 128 def test_update_last_triggered_step(self): 129 timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=1) 130 131 elapsed_secs, elapsed_steps = timer.update_last_triggered_step(1) 132 self.assertEqual(None, elapsed_secs) 133 self.assertEqual(None, elapsed_steps) 134 135 elapsed_secs, elapsed_steps = timer.update_last_triggered_step(5) 136 self.assertLess(0, elapsed_secs) 137 self.assertEqual(4, elapsed_steps) 138 139 elapsed_secs, elapsed_steps = timer.update_last_triggered_step(7) 140 self.assertLess(0, elapsed_secs) 141 self.assertEqual(2, elapsed_steps) 142 143 144class StopAtStepTest(test.TestCase): 145 146 def test_raise_in_both_last_step_and_num_steps(self): 147 with self.assertRaises(ValueError): 148 basic_session_run_hooks.StopAtStepHook(num_steps=10, last_step=20) 149 150 def test_stop_based_on_last_step(self): 151 h = basic_session_run_hooks.StopAtStepHook(last_step=10) 152 with ops.Graph().as_default(): 153 global_step = training_util.get_or_create_global_step() 154 no_op = control_flow_ops.no_op() 155 h.begin() 156 with session_lib.Session() as sess: 157 mon_sess = monitored_session._HookedSession(sess, [h]) 158 sess.run(state_ops.assign(global_step, 5)) 159 h.after_create_session(sess, None) 160 mon_sess.run(no_op) 161 self.assertFalse(mon_sess.should_stop()) 162 sess.run(state_ops.assign(global_step, 9)) 163 mon_sess.run(no_op) 164 self.assertFalse(mon_sess.should_stop()) 165 sess.run(state_ops.assign(global_step, 10)) 166 mon_sess.run(no_op) 167 self.assertTrue(mon_sess.should_stop()) 168 sess.run(state_ops.assign(global_step, 11)) 169 mon_sess._should_stop = False 170 mon_sess.run(no_op) 171 self.assertTrue(mon_sess.should_stop()) 172 173 def test_stop_based_on_num_step(self): 174 h = basic_session_run_hooks.StopAtStepHook(num_steps=10) 175 176 with ops.Graph().as_default(): 177 global_step = training_util.get_or_create_global_step() 178 no_op = control_flow_ops.no_op() 179 h.begin() 180 with session_lib.Session() as sess: 181 mon_sess = monitored_session._HookedSession(sess, [h]) 182 sess.run(state_ops.assign(global_step, 5)) 183 h.after_create_session(sess, None) 184 mon_sess.run(no_op) 185 self.assertFalse(mon_sess.should_stop()) 186 sess.run(state_ops.assign(global_step, 13)) 187 mon_sess.run(no_op) 188 self.assertFalse(mon_sess.should_stop()) 189 sess.run(state_ops.assign(global_step, 14)) 190 mon_sess.run(no_op) 191 self.assertFalse(mon_sess.should_stop()) 192 sess.run(state_ops.assign(global_step, 15)) 193 mon_sess.run(no_op) 194 self.assertTrue(mon_sess.should_stop()) 195 sess.run(state_ops.assign(global_step, 16)) 196 mon_sess._should_stop = False 197 mon_sess.run(no_op) 198 self.assertTrue(mon_sess.should_stop()) 199 200 def test_stop_based_with_multiple_steps(self): 201 h = basic_session_run_hooks.StopAtStepHook(num_steps=10) 202 203 with ops.Graph().as_default(): 204 global_step = training_util.get_or_create_global_step() 205 no_op = control_flow_ops.no_op() 206 h.begin() 207 with session_lib.Session() as sess: 208 mon_sess = monitored_session._HookedSession(sess, [h]) 209 sess.run(state_ops.assign(global_step, 5)) 210 h.after_create_session(sess, None) 211 mon_sess.run(no_op) 212 self.assertFalse(mon_sess.should_stop()) 213 sess.run(state_ops.assign(global_step, 15)) 214 mon_sess.run(no_op) 215 self.assertTrue(mon_sess.should_stop()) 216 217 218class LoggingTensorHookTest(test.TestCase): 219 220 def setUp(self): 221 # Mock out logging calls so we can verify whether correct tensors are being 222 # monitored. 223 self._actual_log = tf_logging.info 224 self.logged_message = None 225 226 def mock_log(*args, **kwargs): 227 self.logged_message = args 228 self._actual_log(*args, **kwargs) 229 230 tf_logging.info = mock_log 231 232 def tearDown(self): 233 tf_logging.info = self._actual_log 234 235 def test_illegal_args(self): 236 with self.assertRaisesRegex(ValueError, 'nvalid every_n_iter'): 237 basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=0) 238 with self.assertRaisesRegex(ValueError, 'nvalid every_n_iter'): 239 basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=-10) 240 with self.assertRaisesRegex(ValueError, 'xactly one of'): 241 basic_session_run_hooks.LoggingTensorHook( 242 tensors=['t'], every_n_iter=5, every_n_secs=5) 243 with self.assertRaisesRegex(ValueError, 'xactly one of'): 244 basic_session_run_hooks.LoggingTensorHook(tensors=['t']) 245 246 def test_print_at_end_only(self): 247 with ops.Graph().as_default(), session_lib.Session() as sess: 248 t = constant_op.constant(42.0, name='foo') 249 train_op = constant_op.constant(3) 250 hook = basic_session_run_hooks.LoggingTensorHook( 251 tensors=[t.name], at_end=True) 252 hook.begin() 253 mon_sess = monitored_session._HookedSession(sess, [hook]) 254 self.evaluate(variables_lib.global_variables_initializer()) 255 self.logged_message = '' 256 for _ in range(3): 257 mon_sess.run(train_op) 258 # assertNotRegexpMatches is not supported by python 3.1 and later 259 self.assertEqual(str(self.logged_message).find(t.name), -1) 260 261 hook.end(sess) 262 self.assertRegex(str(self.logged_message), t.name) 263 264 def _validate_print_every_n_steps(self, sess, at_end): 265 t = constant_op.constant(42.0, name='foo') 266 267 train_op = constant_op.constant(3) 268 hook = basic_session_run_hooks.LoggingTensorHook( 269 tensors=[t.name], every_n_iter=10, at_end=at_end) 270 hook.begin() 271 mon_sess = monitored_session._HookedSession(sess, [hook]) 272 self.evaluate(variables_lib.global_variables_initializer()) 273 mon_sess.run(train_op) 274 self.assertRegex(str(self.logged_message), t.name) 275 for _ in range(3): 276 self.logged_message = '' 277 for _ in range(9): 278 mon_sess.run(train_op) 279 # assertNotRegexpMatches is not supported by python 3.1 and later 280 self.assertEqual(str(self.logged_message).find(t.name), -1) 281 mon_sess.run(train_op) 282 self.assertRegex(str(self.logged_message), t.name) 283 284 # Add additional run to verify proper reset when called multiple times. 285 self.logged_message = '' 286 mon_sess.run(train_op) 287 # assertNotRegexpMatches is not supported by python 3.1 and later 288 self.assertEqual(str(self.logged_message).find(t.name), -1) 289 290 self.logged_message = '' 291 hook.end(sess) 292 if at_end: 293 self.assertRegex(str(self.logged_message), t.name) 294 else: 295 # assertNotRegexpMatches is not supported by python 3.1 and later 296 self.assertEqual(str(self.logged_message).find(t.name), -1) 297 298 def test_print_every_n_steps(self): 299 with ops.Graph().as_default(), session_lib.Session() as sess: 300 self._validate_print_every_n_steps(sess, at_end=False) 301 # Verify proper reset. 302 self._validate_print_every_n_steps(sess, at_end=False) 303 304 def test_print_every_n_steps_and_end(self): 305 with ops.Graph().as_default(), session_lib.Session() as sess: 306 self._validate_print_every_n_steps(sess, at_end=True) 307 # Verify proper reset. 308 self._validate_print_every_n_steps(sess, at_end=True) 309 310 def test_print_first_step(self): 311 # if it runs every iteration, first iteration has None duration. 312 with ops.Graph().as_default(), session_lib.Session() as sess: 313 t = constant_op.constant(42.0, name='foo') 314 train_op = constant_op.constant(3) 315 hook = basic_session_run_hooks.LoggingTensorHook( 316 tensors={'foo': t}, every_n_iter=1) 317 hook.begin() 318 mon_sess = monitored_session._HookedSession(sess, [hook]) 319 self.evaluate(variables_lib.global_variables_initializer()) 320 mon_sess.run(train_op) 321 self.assertRegex(str(self.logged_message), 'foo') 322 # in first run, elapsed time is None. 323 self.assertEqual(str(self.logged_message).find('sec'), -1) 324 325 def _validate_print_every_n_secs(self, sess, at_end, mock_time): 326 t = constant_op.constant(42.0, name='foo') 327 train_op = constant_op.constant(3) 328 329 hook = basic_session_run_hooks.LoggingTensorHook( 330 tensors=[t.name], every_n_secs=1.0, at_end=at_end) 331 hook.begin() 332 mon_sess = monitored_session._HookedSession(sess, [hook]) 333 self.evaluate(variables_lib.global_variables_initializer()) 334 335 mon_sess.run(train_op) 336 self.assertRegex(str(self.logged_message), t.name) 337 338 # assertNotRegexpMatches is not supported by python 3.1 and later 339 self.logged_message = '' 340 mon_sess.run(train_op) 341 self.assertEqual(str(self.logged_message).find(t.name), -1) 342 mock_time.return_value += 1.0 343 344 self.logged_message = '' 345 mon_sess.run(train_op) 346 self.assertRegex(str(self.logged_message), t.name) 347 348 self.logged_message = '' 349 hook.end(sess) 350 if at_end: 351 self.assertRegex(str(self.logged_message), t.name) 352 else: 353 # assertNotRegexpMatches is not supported by python 3.1 and later 354 self.assertEqual(str(self.logged_message).find(t.name), -1) 355 356 @test.mock.patch.object(time, 'time') 357 def test_print_every_n_secs(self, mock_time): 358 with ops.Graph().as_default(), session_lib.Session() as sess: 359 mock_time.return_value = MOCK_START_TIME 360 self._validate_print_every_n_secs(sess, at_end=False, mock_time=mock_time) 361 # Verify proper reset. 362 self._validate_print_every_n_secs(sess, at_end=False, mock_time=mock_time) 363 364 @test.mock.patch.object(time, 'time') 365 def test_print_every_n_secs_and_end(self, mock_time): 366 with ops.Graph().as_default(), session_lib.Session() as sess: 367 mock_time.return_value = MOCK_START_TIME 368 self._validate_print_every_n_secs(sess, at_end=True, mock_time=mock_time) 369 # Verify proper reset. 370 self._validate_print_every_n_secs(sess, at_end=True, mock_time=mock_time) 371 372 def test_print_formatter(self): 373 with ops.Graph().as_default(), session_lib.Session() as sess: 374 t = constant_op.constant(42.0, name='foo') 375 train_op = constant_op.constant(3) 376 hook = basic_session_run_hooks.LoggingTensorHook( 377 tensors=[t.name], every_n_iter=10, 378 formatter=lambda items: 'qqq=%s' % items[t.name]) 379 hook.begin() 380 mon_sess = monitored_session._HookedSession(sess, [hook]) 381 self.evaluate(variables_lib.global_variables_initializer()) 382 mon_sess.run(train_op) 383 self.assertEqual(self.logged_message[0], 'qqq=42.0') 384 385 386class CheckpointSaverHookTest(test.TestCase): 387 388 def setUp(self): 389 self.model_dir = tempfile.mkdtemp() 390 self.graph = ops.Graph() 391 with self.graph.as_default(): 392 self.scaffold = monitored_session.Scaffold() 393 self.global_step = training_util.get_or_create_global_step() 394 self.train_op = training_util._increment_global_step(1) 395 396 def tearDown(self): 397 shutil.rmtree(self.model_dir, ignore_errors=True) 398 399 def test_saves_when_saver_and_scaffold_both_missing(self): 400 with self.graph.as_default(): 401 hook = basic_session_run_hooks.CheckpointSaverHook( 402 self.model_dir, save_steps=1) 403 hook.begin() 404 self.scaffold.finalize() 405 with session_lib.Session() as sess: 406 sess.run(self.scaffold.init_op) 407 mon_sess = monitored_session._HookedSession(sess, [hook]) 408 mon_sess.run(self.train_op) 409 self.assertEqual(1, 410 checkpoint_utils.load_variable(self.model_dir, 411 self.global_step.name)) 412 413 def test_raise_when_saver_and_scaffold_both_present(self): 414 with self.assertRaises(ValueError): 415 basic_session_run_hooks.CheckpointSaverHook( 416 self.model_dir, saver=self.scaffold.saver, scaffold=self.scaffold) 417 418 @test_util.run_deprecated_v1 419 def test_raise_in_both_secs_and_steps(self): 420 with self.assertRaises(ValueError): 421 basic_session_run_hooks.CheckpointSaverHook( 422 self.model_dir, save_secs=10, save_steps=20) 423 424 @test_util.run_deprecated_v1 425 def test_raise_in_none_secs_and_steps(self): 426 with self.assertRaises(ValueError): 427 basic_session_run_hooks.CheckpointSaverHook(self.model_dir) 428 429 def test_save_secs_saves_in_first_step(self): 430 with self.graph.as_default(): 431 hook = basic_session_run_hooks.CheckpointSaverHook( 432 self.model_dir, save_secs=2, scaffold=self.scaffold) 433 hook.begin() 434 self.scaffold.finalize() 435 with session_lib.Session() as sess: 436 sess.run(self.scaffold.init_op) 437 mon_sess = monitored_session._HookedSession(sess, [hook]) 438 mon_sess.run(self.train_op) 439 self.assertEqual(1, 440 checkpoint_utils.load_variable(self.model_dir, 441 self.global_step.name)) 442 443 def test_save_secs_calls_listeners_at_begin_and_end(self): 444 with self.graph.as_default(): 445 listener = MockCheckpointSaverListener() 446 hook = basic_session_run_hooks.CheckpointSaverHook( 447 self.model_dir, 448 save_secs=2, 449 scaffold=self.scaffold, 450 listeners=[listener]) 451 hook.begin() 452 self.scaffold.finalize() 453 with session_lib.Session() as sess: 454 sess.run(self.scaffold.init_op) 455 mon_sess = monitored_session._HookedSession(sess, [hook]) 456 mon_sess.run(self.train_op) # hook runs here 457 mon_sess.run(self.train_op) # hook won't run here, so it does at end 458 hook.end(sess) # hook runs here 459 self.assertEqual({ 460 'begin': 1, 461 'before_save': 2, 462 'after_save': 2, 463 'end': 1 464 }, listener.get_counts()) 465 466 def test_listener_with_monitored_session(self): 467 with ops.Graph().as_default(): 468 scaffold = monitored_session.Scaffold() 469 global_step = training_util.get_or_create_global_step() 470 train_op = training_util._increment_global_step(1) 471 listener = MockCheckpointSaverListener() 472 hook = basic_session_run_hooks.CheckpointSaverHook( 473 self.model_dir, 474 save_steps=1, 475 scaffold=scaffold, 476 listeners=[listener]) 477 with monitored_session.SingularMonitoredSession( 478 hooks=[hook], 479 scaffold=scaffold, 480 checkpoint_dir=self.model_dir) as sess: 481 sess.run(train_op) 482 sess.run(train_op) 483 global_step_val = sess.raw_session().run(global_step) 484 listener_counts = listener.get_counts() 485 self.assertEqual(2, global_step_val) 486 self.assertEqual({ 487 'begin': 1, 488 'before_save': 3, 489 'after_save': 3, 490 'end': 1 491 }, listener_counts) 492 493 def test_listener_stops_training_in_after_save(self): 494 with ops.Graph().as_default(): 495 scaffold = monitored_session.Scaffold() 496 training_util.get_or_create_global_step() 497 train_op = training_util._increment_global_step(1) 498 listener = MockCheckpointSaverListener() 499 hook = basic_session_run_hooks.CheckpointSaverHook( 500 self.model_dir, save_steps=1, scaffold=scaffold, listeners=[listener]) 501 with monitored_session.SingularMonitoredSession( 502 hooks=[hook], scaffold=scaffold, 503 checkpoint_dir=self.model_dir) as sess: 504 sess.run(train_op) 505 self.assertFalse(sess.should_stop()) 506 sess.run(train_op) 507 self.assertFalse(sess.should_stop()) 508 listener.ask_for_stop = True 509 sess.run(train_op) 510 self.assertTrue(sess.should_stop()) 511 512 def test_listener_with_default_saver(self): 513 with ops.Graph().as_default(): 514 global_step = training_util.get_or_create_global_step() 515 train_op = training_util._increment_global_step(1) 516 listener = MockCheckpointSaverListener() 517 hook = basic_session_run_hooks.CheckpointSaverHook( 518 self.model_dir, 519 save_steps=1, 520 listeners=[listener]) 521 with monitored_session.SingularMonitoredSession( 522 hooks=[hook], 523 checkpoint_dir=self.model_dir) as sess: 524 sess.run(train_op) 525 sess.run(train_op) 526 global_step_val = sess.raw_session().run(global_step) 527 listener_counts = listener.get_counts() 528 self.assertEqual(2, global_step_val) 529 self.assertEqual({ 530 'begin': 1, 531 'before_save': 3, 532 'after_save': 3, 533 'end': 1 534 }, listener_counts) 535 536 with ops.Graph().as_default(): 537 global_step = training_util.get_or_create_global_step() 538 with monitored_session.SingularMonitoredSession( 539 checkpoint_dir=self.model_dir) as sess2: 540 global_step_saved_val = sess2.run(global_step) 541 self.assertEqual(2, global_step_saved_val) 542 543 def test_two_listeners_with_default_saver(self): 544 with ops.Graph().as_default(): 545 global_step = training_util.get_or_create_global_step() 546 train_op = training_util._increment_global_step(1) 547 listener1 = MockCheckpointSaverListener() 548 listener2 = MockCheckpointSaverListener() 549 hook = basic_session_run_hooks.CheckpointSaverHook( 550 self.model_dir, 551 save_steps=1, 552 listeners=[listener1, listener2]) 553 with monitored_session.SingularMonitoredSession( 554 hooks=[hook], 555 checkpoint_dir=self.model_dir) as sess: 556 sess.run(train_op) 557 sess.run(train_op) 558 global_step_val = sess.raw_session().run(global_step) 559 listener1_counts = listener1.get_counts() 560 listener2_counts = listener2.get_counts() 561 self.assertEqual(2, global_step_val) 562 self.assertEqual({ 563 'begin': 1, 564 'before_save': 3, 565 'after_save': 3, 566 'end': 1 567 }, listener1_counts) 568 self.assertEqual(listener1_counts, listener2_counts) 569 570 with ops.Graph().as_default(): 571 global_step = training_util.get_or_create_global_step() 572 with monitored_session.SingularMonitoredSession( 573 checkpoint_dir=self.model_dir) as sess2: 574 global_step_saved_val = sess2.run(global_step) 575 self.assertEqual(2, global_step_saved_val) 576 577 @test.mock.patch.object(time, 'time') 578 def test_save_secs_saves_periodically(self, mock_time): 579 with self.graph.as_default(): 580 mock_time.return_value = MOCK_START_TIME 581 hook = basic_session_run_hooks.CheckpointSaverHook( 582 self.model_dir, save_secs=2, scaffold=self.scaffold) 583 hook.begin() 584 self.scaffold.finalize() 585 586 with session_lib.Session() as sess: 587 sess.run(self.scaffold.init_op) 588 mon_sess = monitored_session._HookedSession(sess, [hook]) 589 590 mock_time.return_value = MOCK_START_TIME 591 mon_sess.run(self.train_op) # Saved. 592 593 mock_time.return_value = MOCK_START_TIME + 0.5 594 mon_sess.run(self.train_op) # Not saved. 595 596 self.assertEqual(1, 597 checkpoint_utils.load_variable(self.model_dir, 598 self.global_step.name)) 599 600 # Simulate 2.5 seconds of sleep. 601 mock_time.return_value = MOCK_START_TIME + 2.5 602 mon_sess.run(self.train_op) # Saved. 603 604 mock_time.return_value = MOCK_START_TIME + 2.6 605 mon_sess.run(self.train_op) # Not saved. 606 607 mock_time.return_value = MOCK_START_TIME + 2.7 608 mon_sess.run(self.train_op) # Not saved. 609 610 self.assertEqual(3, 611 checkpoint_utils.load_variable(self.model_dir, 612 self.global_step.name)) 613 614 # Simulate 7.5 more seconds of sleep (10 seconds from start. 615 mock_time.return_value = MOCK_START_TIME + 10 616 mon_sess.run(self.train_op) # Saved. 617 self.assertEqual(6, 618 checkpoint_utils.load_variable(self.model_dir, 619 self.global_step.name)) 620 621 @test.mock.patch.object(time, 'time') 622 def test_save_secs_calls_listeners_periodically(self, mock_time): 623 with self.graph.as_default(): 624 mock_time.return_value = MOCK_START_TIME 625 listener = MockCheckpointSaverListener() 626 hook = basic_session_run_hooks.CheckpointSaverHook( 627 self.model_dir, 628 save_secs=2, 629 scaffold=self.scaffold, 630 listeners=[listener]) 631 hook.begin() 632 self.scaffold.finalize() 633 with session_lib.Session() as sess: 634 sess.run(self.scaffold.init_op) 635 mon_sess = monitored_session._HookedSession(sess, [hook]) 636 637 mock_time.return_value = MOCK_START_TIME + 0.5 638 mon_sess.run(self.train_op) # hook runs here 639 640 mock_time.return_value = MOCK_START_TIME + 0.5 641 mon_sess.run(self.train_op) 642 643 mock_time.return_value = MOCK_START_TIME + 3.0 644 mon_sess.run(self.train_op) # hook runs here 645 646 mock_time.return_value = MOCK_START_TIME + 3.5 647 mon_sess.run(self.train_op) 648 649 mock_time.return_value = MOCK_START_TIME + 4.0 650 mon_sess.run(self.train_op) 651 652 mock_time.return_value = MOCK_START_TIME + 6.5 653 mon_sess.run(self.train_op) # hook runs here 654 655 mock_time.return_value = MOCK_START_TIME + 7.0 656 mon_sess.run(self.train_op) # hook won't run here, so it does at end 657 658 mock_time.return_value = MOCK_START_TIME + 7.5 659 hook.end(sess) # hook runs here 660 self.assertEqual({ 661 'begin': 1, 662 'before_save': 4, 663 'after_save': 4, 664 'end': 1 665 }, listener.get_counts()) 666 667 def test_save_steps_saves_in_first_step(self): 668 with self.graph.as_default(): 669 hook = basic_session_run_hooks.CheckpointSaverHook( 670 self.model_dir, save_steps=2, scaffold=self.scaffold) 671 hook.begin() 672 self.scaffold.finalize() 673 with session_lib.Session() as sess: 674 sess.run(self.scaffold.init_op) 675 mon_sess = monitored_session._HookedSession(sess, [hook]) 676 mon_sess.run(self.train_op) 677 self.assertEqual(1, 678 checkpoint_utils.load_variable(self.model_dir, 679 self.global_step.name)) 680 681 def test_save_steps_saves_periodically(self): 682 with self.graph.as_default(): 683 hook = basic_session_run_hooks.CheckpointSaverHook( 684 self.model_dir, save_steps=2, scaffold=self.scaffold) 685 hook.begin() 686 self.scaffold.finalize() 687 with session_lib.Session() as sess: 688 sess.run(self.scaffold.init_op) 689 mon_sess = monitored_session._HookedSession(sess, [hook]) 690 mon_sess.run(self.train_op) 691 mon_sess.run(self.train_op) 692 # Not saved 693 self.assertEqual(1, 694 checkpoint_utils.load_variable(self.model_dir, 695 self.global_step.name)) 696 mon_sess.run(self.train_op) 697 # saved 698 self.assertEqual(3, 699 checkpoint_utils.load_variable(self.model_dir, 700 self.global_step.name)) 701 mon_sess.run(self.train_op) 702 # Not saved 703 self.assertEqual(3, 704 checkpoint_utils.load_variable(self.model_dir, 705 self.global_step.name)) 706 mon_sess.run(self.train_op) 707 # saved 708 self.assertEqual(5, 709 checkpoint_utils.load_variable(self.model_dir, 710 self.global_step.name)) 711 712 def test_save_saves_at_end(self): 713 with self.graph.as_default(): 714 hook = basic_session_run_hooks.CheckpointSaverHook( 715 self.model_dir, save_secs=2, scaffold=self.scaffold) 716 hook.begin() 717 self.scaffold.finalize() 718 with session_lib.Session() as sess: 719 sess.run(self.scaffold.init_op) 720 mon_sess = monitored_session._HookedSession(sess, [hook]) 721 mon_sess.run(self.train_op) 722 mon_sess.run(self.train_op) 723 hook.end(sess) 724 self.assertEqual(2, 725 checkpoint_utils.load_variable(self.model_dir, 726 self.global_step.name)) 727 728 def test_summary_writer_defs(self): 729 fake_summary_writer.FakeSummaryWriter.install() 730 writer_cache.FileWriterCache.clear() 731 summary_writer = writer_cache.FileWriterCache.get(self.model_dir) 732 733 with self.graph.as_default(): 734 hook = basic_session_run_hooks.CheckpointSaverHook( 735 self.model_dir, save_steps=2, scaffold=self.scaffold) 736 hook.begin() 737 self.scaffold.finalize() 738 with session_lib.Session() as sess: 739 sess.run(self.scaffold.init_op) 740 mon_sess = monitored_session._HookedSession(sess, [hook]) 741 hook.after_create_session(sess, None) 742 mon_sess.run(self.train_op) 743 summary_writer.assert_summaries( 744 test_case=self, 745 expected_logdir=self.model_dir, 746 expected_added_meta_graphs=[ 747 meta_graph.create_meta_graph_def( 748 graph_def=self.graph.as_graph_def(add_shapes=True), 749 saver_def=self.scaffold.saver.saver_def) 750 ]) 751 752 fake_summary_writer.FakeSummaryWriter.uninstall() 753 754 def test_save_checkpoint_before_first_train_step(self): 755 with self.graph.as_default(): 756 hook = basic_session_run_hooks.CheckpointSaverHook( 757 self.model_dir, save_steps=2, scaffold=self.scaffold) 758 hook.begin() 759 self.scaffold.finalize() 760 with session_lib.Session() as sess: 761 mon_sess = monitored_session._HookedSession(sess, [hook]) 762 sess.run(self.scaffold.init_op) 763 hook.after_create_session(sess, None) 764 # Verifies that checkpoint is saved at step 0. 765 self.assertEqual(0, 766 checkpoint_utils.load_variable(self.model_dir, 767 self.global_step.name)) 768 # Verifies that no checkpoint is saved after one training step. 769 mon_sess.run(self.train_op) 770 self.assertEqual(0, 771 checkpoint_utils.load_variable(self.model_dir, 772 self.global_step.name)) 773 # Verifies that checkpoint is saved after save_steps. 774 mon_sess.run(self.train_op) 775 self.assertEqual(2, 776 checkpoint_utils.load_variable(self.model_dir, 777 self.global_step.name)) 778 779 def test_save_graph_def(self): 780 with self.graph.as_default(): 781 hook = basic_session_run_hooks.CheckpointSaverHook( 782 self.model_dir, save_steps=1, scaffold=self.scaffold, 783 save_graph_def=True) 784 hook.begin() 785 self.scaffold.finalize() 786 with session_lib.Session() as sess: 787 sess.run(self.scaffold.init_op) 788 mon_sess = monitored_session._HookedSession(sess, [hook]) 789 sess.run(self.scaffold.init_op) 790 hook.after_create_session(sess, None) 791 792 self.assertIn('graph.pbtxt', os.listdir(self.model_dir)) 793 # Should have a single .meta file for step 0 794 self.assertLen(gfile.Glob(os.path.join(self.model_dir, '*.meta')), 1) 795 796 mon_sess.run(self.train_op) 797 self.assertLen(gfile.Glob(os.path.join(self.model_dir, '*.meta')), 2) 798 799 def test_save_graph_def_false(self): 800 with self.graph.as_default(): 801 hook = basic_session_run_hooks.CheckpointSaverHook( 802 self.model_dir, save_steps=1, scaffold=self.scaffold, 803 save_graph_def=False) 804 hook.begin() 805 self.scaffold.finalize() 806 with session_lib.Session() as sess: 807 sess.run(self.scaffold.init_op) 808 mon_sess = monitored_session._HookedSession(sess, [hook]) 809 sess.run(self.scaffold.init_op) 810 hook.after_create_session(sess, None) 811 812 self.assertNotIn('graph.pbtxt', os.listdir(self.model_dir)) 813 # Should have a single .meta file for step 0 814 self.assertEmpty(gfile.Glob(os.path.join(self.model_dir, '*.meta'))) 815 816 mon_sess.run(self.train_op) 817 self.assertEmpty(gfile.Glob(os.path.join(self.model_dir, '*.meta'))) 818 819 820 821 822class CheckpointSaverHookMultiStepTest(test.TestCase): 823 824 def setUp(self): 825 self.model_dir = tempfile.mkdtemp() 826 self.graph = ops.Graph() 827 self.steps_per_run = 5 828 with self.graph.as_default(): 829 self.scaffold = monitored_session.Scaffold() 830 self.global_step = training_util.get_or_create_global_step() 831 self.train_op = training_util._increment_global_step(self.steps_per_run) 832 833 def tearDown(self): 834 shutil.rmtree(self.model_dir, ignore_errors=True) 835 836 def test_save_steps_saves_in_first_step(self): 837 with self.graph.as_default(): 838 hook = basic_session_run_hooks.CheckpointSaverHook( 839 self.model_dir, 840 save_steps=2*self.steps_per_run, 841 scaffold=self.scaffold) 842 hook._set_steps_per_run(self.steps_per_run) 843 hook.begin() 844 self.scaffold.finalize() 845 with session_lib.Session() as sess: 846 sess.run(self.scaffold.init_op) 847 mon_sess = monitored_session._HookedSession(sess, [hook]) 848 mon_sess.run(self.train_op) 849 self.assertEqual(5, 850 checkpoint_utils.load_variable(self.model_dir, 851 self.global_step.name)) 852 853 def test_save_steps_saves_periodically(self): 854 with self.graph.as_default(): 855 hook = basic_session_run_hooks.CheckpointSaverHook( 856 self.model_dir, 857 save_steps=2*self.steps_per_run, 858 scaffold=self.scaffold) 859 hook._set_steps_per_run(self.steps_per_run) 860 hook.begin() 861 self.scaffold.finalize() 862 with session_lib.Session() as sess: 863 sess.run(self.scaffold.init_op) 864 mon_sess = monitored_session._HookedSession(sess, [hook]) 865 mon_sess.run(self.train_op) 866 # Saved (step=5) 867 self.assertEqual(5, 868 checkpoint_utils.load_variable(self.model_dir, 869 self.global_step.name)) 870 871 mon_sess.run(self.train_op) 872 # Not saved (step=10) 873 self.assertEqual(5, 874 checkpoint_utils.load_variable(self.model_dir, 875 self.global_step.name)) 876 877 mon_sess.run(self.train_op) 878 # Saved (step=15) 879 self.assertEqual(15, 880 checkpoint_utils.load_variable(self.model_dir, 881 self.global_step.name)) 882 883 mon_sess.run(self.train_op) 884 # Not saved (step=20) 885 self.assertEqual(15, 886 checkpoint_utils.load_variable(self.model_dir, 887 self.global_step.name)) 888 889 mon_sess.run(self.train_op) 890 # Saved (step=25) 891 self.assertEqual(25, 892 checkpoint_utils.load_variable(self.model_dir, 893 self.global_step.name)) 894 895 def test_save_steps_saves_at_end(self): 896 with self.graph.as_default(): 897 hook = basic_session_run_hooks.CheckpointSaverHook( 898 self.model_dir, 899 save_steps=2*self.steps_per_run, 900 scaffold=self.scaffold) 901 hook._set_steps_per_run(self.steps_per_run) 902 hook.begin() 903 self.scaffold.finalize() 904 with session_lib.Session() as sess: 905 sess.run(self.scaffold.init_op) 906 mon_sess = monitored_session._HookedSession(sess, [hook]) 907 mon_sess.run(self.train_op) 908 mon_sess.run(self.train_op) 909 hook.end(sess) 910 self.assertEqual(10, 911 checkpoint_utils.load_variable(self.model_dir, 912 self.global_step.name)) 913 914 915class ResourceCheckpointSaverHookTest(test.TestCase): 916 917 def setUp(self): 918 self.model_dir = tempfile.mkdtemp() 919 self.graph = ops.Graph() 920 with self.graph.as_default(): 921 self.scaffold = monitored_session.Scaffold() 922 with variable_scope.variable_scope('foo', use_resource=True): 923 self.global_step = training_util.get_or_create_global_step() 924 self.train_op = training_util._increment_global_step(1) 925 926 def test_save_steps_saves_periodically(self): 927 with self.graph.as_default(): 928 hook = basic_session_run_hooks.CheckpointSaverHook( 929 self.model_dir, save_steps=2, scaffold=self.scaffold) 930 hook.begin() 931 self.scaffold.finalize() 932 with session_lib.Session() as sess: 933 sess.run(self.scaffold.init_op) 934 mon_sess = monitored_session._HookedSession(sess, [hook]) 935 mon_sess.run(self.train_op) 936 mon_sess.run(self.train_op) 937 # Not saved 938 self.assertEqual(1, 939 checkpoint_utils.load_variable(self.model_dir, 940 self.global_step.name)) 941 mon_sess.run(self.train_op) 942 # saved 943 self.assertEqual(3, 944 checkpoint_utils.load_variable(self.model_dir, 945 self.global_step.name)) 946 mon_sess.run(self.train_op) 947 # Not saved 948 self.assertEqual(3, 949 checkpoint_utils.load_variable(self.model_dir, 950 self.global_step.name)) 951 mon_sess.run(self.train_op) 952 # saved 953 self.assertEqual(5, 954 checkpoint_utils.load_variable(self.model_dir, 955 self.global_step.name)) 956 957 958class StepCounterHookTest(test.TestCase): 959 960 def setUp(self): 961 self.log_dir = tempfile.mkdtemp() 962 963 def tearDown(self): 964 shutil.rmtree(self.log_dir, ignore_errors=True) 965 966 @test.mock.patch.object(time, 'time') 967 def test_step_counter_every_n_steps(self, mock_time): 968 mock_time.return_value = MOCK_START_TIME 969 with ops.Graph().as_default() as g, session_lib.Session() as sess: 970 training_util.get_or_create_global_step() 971 train_op = training_util._increment_global_step(1) 972 summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g) 973 hook = basic_session_run_hooks.StepCounterHook( 974 summary_writer=summary_writer, every_n_steps=10) 975 hook.begin() 976 self.evaluate(variables_lib.global_variables_initializer()) 977 mon_sess = monitored_session._HookedSession(sess, [hook]) 978 with test.mock.patch.object(tf_logging, 'warning') as mock_log: 979 for _ in range(30): 980 mock_time.return_value += 0.01 981 mon_sess.run(train_op) 982 # logging.warning should not be called. 983 self.assertIsNone(mock_log.call_args) 984 hook.end(sess) 985 summary_writer.assert_summaries( 986 test_case=self, 987 expected_logdir=self.log_dir, 988 expected_graph=g, 989 expected_summaries={}) 990 self.assertItemsEqual([11, 21], summary_writer.summaries.keys()) 991 for step in [11, 21]: 992 summary_value = summary_writer.summaries[step][0].value[0] 993 self.assertEqual('global_step/sec', summary_value.tag) 994 self.assertGreater(summary_value.simple_value, 0) 995 996 @test.mock.patch.object(time, 'time') 997 def test_step_counter_every_n_secs(self, mock_time): 998 mock_time.return_value = MOCK_START_TIME 999 with ops.Graph().as_default() as g, session_lib.Session() as sess: 1000 training_util.get_or_create_global_step() 1001 train_op = training_util._increment_global_step(1) 1002 summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g) 1003 hook = basic_session_run_hooks.StepCounterHook( 1004 summary_writer=summary_writer, every_n_steps=None, every_n_secs=0.1) 1005 1006 hook.begin() 1007 self.evaluate(variables_lib.global_variables_initializer()) 1008 mon_sess = monitored_session._HookedSession(sess, [hook]) 1009 mon_sess.run(train_op) 1010 mock_time.return_value += 0.2 1011 mon_sess.run(train_op) 1012 mock_time.return_value += 0.2 1013 mon_sess.run(train_op) 1014 hook.end(sess) 1015 1016 summary_writer.assert_summaries( 1017 test_case=self, 1018 expected_logdir=self.log_dir, 1019 expected_graph=g, 1020 expected_summaries={}) 1021 self.assertTrue(summary_writer.summaries, 'No summaries were created.') 1022 self.assertItemsEqual([2, 3], summary_writer.summaries.keys()) 1023 for summary in summary_writer.summaries.values(): 1024 summary_value = summary[0].value[0] 1025 self.assertEqual('global_step/sec', summary_value.tag) 1026 self.assertGreater(summary_value.simple_value, 0) 1027 1028 def test_global_step_name(self): 1029 with ops.Graph().as_default() as g, session_lib.Session() as sess: 1030 with variable_scope.variable_scope('bar'): 1031 variable_scope.get_variable( 1032 'foo', 1033 initializer=0, 1034 trainable=False, 1035 collections=[ 1036 ops.GraphKeys.GLOBAL_STEP, ops.GraphKeys.GLOBAL_VARIABLES 1037 ]) 1038 train_op = training_util._increment_global_step(1) 1039 summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g) 1040 hook = basic_session_run_hooks.StepCounterHook( 1041 summary_writer=summary_writer, every_n_steps=1, every_n_secs=None) 1042 1043 hook.begin() 1044 self.evaluate(variables_lib.global_variables_initializer()) 1045 mon_sess = monitored_session._HookedSession(sess, [hook]) 1046 mon_sess.run(train_op) 1047 mon_sess.run(train_op) 1048 hook.end(sess) 1049 1050 summary_writer.assert_summaries( 1051 test_case=self, 1052 expected_logdir=self.log_dir, 1053 expected_graph=g, 1054 expected_summaries={}) 1055 self.assertTrue(summary_writer.summaries, 'No summaries were created.') 1056 self.assertItemsEqual([2], summary_writer.summaries.keys()) 1057 summary_value = summary_writer.summaries[2][0].value[0] 1058 self.assertEqual('bar/foo/sec', summary_value.tag) 1059 1060 def test_log_warning_if_global_step_not_increased(self): 1061 with ops.Graph().as_default(), session_lib.Session() as sess: 1062 training_util.get_or_create_global_step() 1063 train_op = training_util._increment_global_step(0) # keep same. 1064 self.evaluate(variables_lib.global_variables_initializer()) 1065 hook = basic_session_run_hooks.StepCounterHook( 1066 every_n_steps=1, every_n_secs=None) 1067 hook.begin() 1068 mon_sess = monitored_session._HookedSession(sess, [hook]) 1069 mon_sess.run(train_op) # Run one step to record global step. 1070 with test.mock.patch.object(tf_logging, 'log_first_n') as mock_log: 1071 for _ in range(30): 1072 mon_sess.run(train_op) 1073 self.assertRegex( 1074 str(mock_log.call_args), 'global step.*has not been increased') 1075 hook.end(sess) 1076 1077 def _setup_steps_per_run_test(self, 1078 every_n_steps, 1079 steps_per_run, 1080 graph, 1081 sess): 1082 training_util.get_or_create_global_step() 1083 self.train_op = training_util._increment_global_step(steps_per_run) 1084 self.summary_writer = fake_summary_writer.FakeSummaryWriter( 1085 self.log_dir, graph) 1086 self.hook = basic_session_run_hooks.StepCounterHook( 1087 summary_writer=self.summary_writer, every_n_steps=every_n_steps) 1088 self.hook._set_steps_per_run(steps_per_run) 1089 self.hook.begin() 1090 self.evaluate(variables_lib.global_variables_initializer()) 1091 self.mon_sess = monitored_session._HookedSession(sess, [self.hook]) 1092 1093 @test.mock.patch.object(time, 'time') 1094 def test_steps_per_run_less_than_every_n_steps(self, mock_time): 1095 mock_time.return_value = MOCK_START_TIME 1096 with ops.Graph().as_default() as g, session_lib.Session() as sess: 1097 self._setup_steps_per_run_test(10, 5, g, sess) 1098 1099 # Logs at 15, 25 1100 for _ in range(5): 1101 mock_time.return_value += 0.01 1102 self.mon_sess.run(self.train_op) 1103 1104 self.hook.end(sess) 1105 self.summary_writer.assert_summaries( 1106 test_case=self, 1107 expected_logdir=self.log_dir, 1108 expected_graph=g, 1109 expected_summaries={}) 1110 self.assertItemsEqual([15, 25], self.summary_writer.summaries.keys()) 1111 for step in [15, 25]: 1112 summary_value = self.summary_writer.summaries[step][0].value[0] 1113 self.assertEqual('global_step/sec', summary_value.tag) 1114 self.assertGreater(summary_value.simple_value, 0) 1115 1116 @test.mock.patch.object(time, 'time') 1117 def test_steps_per_run_equal_every_n_steps(self, mock_time): 1118 mock_time.return_value = MOCK_START_TIME 1119 with ops.Graph().as_default() as g, session_lib.Session() as sess: 1120 self._setup_steps_per_run_test(5, 5, g, sess) 1121 1122 # Logs at 10, 15, 20, 25 1123 for _ in range(5): 1124 mock_time.return_value += 0.01 1125 self.mon_sess.run(self.train_op) 1126 1127 self.hook.end(sess) 1128 self.summary_writer.assert_summaries( 1129 test_case=self, 1130 expected_logdir=self.log_dir, 1131 expected_graph=g, 1132 expected_summaries={}) 1133 self.assertItemsEqual([10, 15, 20, 25], 1134 self.summary_writer.summaries.keys()) 1135 for step in [10, 15, 20, 25]: 1136 summary_value = self.summary_writer.summaries[step][0].value[0] 1137 self.assertEqual('global_step/sec', summary_value.tag) 1138 self.assertGreater(summary_value.simple_value, 0) 1139 1140 @test.mock.patch.object(time, 'time') 1141 def test_steps_per_run_greater_than_every_n_steps(self, mock_time): 1142 mock_time.return_value = MOCK_START_TIME 1143 with ops.Graph().as_default() as g, session_lib.Session() as sess: 1144 self._setup_steps_per_run_test(5, 10, g, sess) 1145 1146 # Logs at 20, 30, 40, 50 1147 for _ in range(5): 1148 mock_time.return_value += 0.01 1149 self.mon_sess.run(self.train_op) 1150 1151 self.hook.end(sess) 1152 self.summary_writer.assert_summaries( 1153 test_case=self, 1154 expected_logdir=self.log_dir, 1155 expected_graph=g, 1156 expected_summaries={}) 1157 self.assertItemsEqual([20, 30, 40, 50], 1158 self.summary_writer.summaries.keys()) 1159 for step in [20, 30, 40, 50]: 1160 summary_value = self.summary_writer.summaries[step][0].value[0] 1161 self.assertEqual('global_step/sec', summary_value.tag) 1162 self.assertGreater(summary_value.simple_value, 0) 1163 1164 1165@test_util.run_deprecated_v1 1166class SummarySaverHookTest(test.TestCase): 1167 1168 def setUp(self): 1169 test.TestCase.setUp(self) 1170 1171 self.log_dir = 'log/dir' 1172 self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir) 1173 1174 var = variables_lib.Variable(0.0) 1175 tensor = state_ops.assign_add(var, 1.0) 1176 tensor2 = tensor * 2 1177 self.summary_op = summary_lib.scalar('my_summary', tensor) 1178 self.summary_op2 = summary_lib.scalar('my_summary2', tensor2) 1179 1180 training_util.get_or_create_global_step() 1181 self.train_op = training_util._increment_global_step(1) 1182 1183 def test_raise_when_scaffold_and_summary_op_both_missing(self): 1184 with self.assertRaises(ValueError): 1185 basic_session_run_hooks.SummarySaverHook() 1186 1187 def test_raise_when_scaffold_and_summary_op_both_present(self): 1188 with self.assertRaises(ValueError): 1189 basic_session_run_hooks.SummarySaverHook( 1190 scaffold=monitored_session.Scaffold(), summary_op=self.summary_op) 1191 1192 def test_raise_in_both_secs_and_steps(self): 1193 with self.assertRaises(ValueError): 1194 basic_session_run_hooks.SummarySaverHook( 1195 save_secs=10, save_steps=20, summary_writer=self.summary_writer) 1196 1197 def test_raise_in_none_secs_and_steps(self): 1198 with self.assertRaises(ValueError): 1199 basic_session_run_hooks.SummarySaverHook( 1200 save_secs=None, save_steps=None, summary_writer=self.summary_writer) 1201 1202 def test_save_steps(self): 1203 hook = basic_session_run_hooks.SummarySaverHook( 1204 save_steps=8, 1205 summary_writer=self.summary_writer, 1206 summary_op=self.summary_op) 1207 1208 with self.cached_session() as sess: 1209 hook.begin() 1210 self.evaluate(variables_lib.global_variables_initializer()) 1211 mon_sess = monitored_session._HookedSession(sess, [hook]) 1212 for _ in range(30): 1213 mon_sess.run(self.train_op) 1214 hook.end(sess) 1215 1216 self.summary_writer.assert_summaries( 1217 test_case=self, 1218 expected_logdir=self.log_dir, 1219 expected_summaries={ 1220 1: { 1221 'my_summary': 1.0 1222 }, 1223 9: { 1224 'my_summary': 2.0 1225 }, 1226 17: { 1227 'my_summary': 3.0 1228 }, 1229 25: { 1230 'my_summary': 4.0 1231 }, 1232 }) 1233 1234 def test_multiple_summaries(self): 1235 hook = basic_session_run_hooks.SummarySaverHook( 1236 save_steps=8, 1237 summary_writer=self.summary_writer, 1238 summary_op=[self.summary_op, self.summary_op2]) 1239 1240 with self.cached_session() as sess: 1241 hook.begin() 1242 self.evaluate(variables_lib.global_variables_initializer()) 1243 mon_sess = monitored_session._HookedSession(sess, [hook]) 1244 for _ in range(10): 1245 mon_sess.run(self.train_op) 1246 hook.end(sess) 1247 1248 self.summary_writer.assert_summaries( 1249 test_case=self, 1250 expected_logdir=self.log_dir, 1251 expected_summaries={ 1252 1: { 1253 'my_summary': 1.0, 1254 'my_summary2': 2.0 1255 }, 1256 9: { 1257 'my_summary': 2.0, 1258 'my_summary2': 4.0 1259 }, 1260 }) 1261 1262 @test.mock.patch.object(time, 'time') 1263 def test_save_secs_saving_once_every_step(self, mock_time): 1264 mock_time.return_value = MOCK_START_TIME 1265 hook = basic_session_run_hooks.SummarySaverHook( 1266 save_secs=0.5, 1267 summary_writer=self.summary_writer, 1268 summary_op=self.summary_op) 1269 1270 with self.cached_session() as sess: 1271 hook.begin() 1272 self.evaluate(variables_lib.global_variables_initializer()) 1273 mon_sess = monitored_session._HookedSession(sess, [hook]) 1274 for _ in range(4): 1275 mon_sess.run(self.train_op) 1276 mock_time.return_value += 0.5 1277 hook.end(sess) 1278 1279 self.summary_writer.assert_summaries( 1280 test_case=self, 1281 expected_logdir=self.log_dir, 1282 expected_summaries={ 1283 1: { 1284 'my_summary': 1.0 1285 }, 1286 2: { 1287 'my_summary': 2.0 1288 }, 1289 3: { 1290 'my_summary': 3.0 1291 }, 1292 4: { 1293 'my_summary': 4.0 1294 }, 1295 }) 1296 1297 @test.mock.patch.object(time, 'time') 1298 def test_save_secs_saving_once_every_three_steps(self, mock_time): 1299 mock_time.return_value = 1484695987.209386 1300 hook = basic_session_run_hooks.SummarySaverHook( 1301 save_secs=9., 1302 summary_writer=self.summary_writer, 1303 summary_op=self.summary_op) 1304 1305 with self.cached_session() as sess: 1306 hook.begin() 1307 self.evaluate(variables_lib.global_variables_initializer()) 1308 mon_sess = monitored_session._HookedSession(sess, [hook]) 1309 for _ in range(8): 1310 mon_sess.run(self.train_op) 1311 mock_time.return_value += 3.1 1312 hook.end(sess) 1313 1314 # 24.8 seconds passed (3.1*8), it saves every 9 seconds starting from first: 1315 self.summary_writer.assert_summaries( 1316 test_case=self, 1317 expected_logdir=self.log_dir, 1318 expected_summaries={ 1319 1: { 1320 'my_summary': 1.0 1321 }, 1322 4: { 1323 'my_summary': 2.0 1324 }, 1325 7: { 1326 'my_summary': 3.0 1327 }, 1328 }) 1329 1330 1331class GlobalStepWaiterHookTest(test.TestCase): 1332 1333 def test_not_wait_for_step_zero(self): 1334 with ops.Graph().as_default(): 1335 training_util.get_or_create_global_step() 1336 hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=0) 1337 hook.begin() 1338 with session_lib.Session() as sess: 1339 # Before run should return without waiting gstep increment. 1340 hook.before_run( 1341 session_run_hook.SessionRunContext( 1342 original_args=None, session=sess)) 1343 1344 @test.mock.patch.object(time, 'sleep') 1345 def test_wait_for_step(self, mock_sleep): 1346 with ops.Graph().as_default(): 1347 gstep = training_util.get_or_create_global_step() 1348 hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=1000) 1349 hook.begin() 1350 1351 with session_lib.Session() as sess: 1352 # Mock out calls to time.sleep() to update the global step. 1353 1354 class Context(object): 1355 counter = 0 1356 1357 def mock_sleep_side_effect(seconds): 1358 del seconds # argument is ignored 1359 Context.counter += 1 1360 if Context.counter == 1: 1361 # The first time sleep() is called, we update the global_step from 1362 # 0 to 500. 1363 sess.run(state_ops.assign(gstep, 500)) 1364 elif Context.counter == 2: 1365 # The second time sleep() is called, we update the global_step from 1366 # 500 to 1100. 1367 sess.run(state_ops.assign(gstep, 1100)) 1368 else: 1369 raise AssertionError( 1370 'Expected before_run() to terminate after the second call to ' 1371 'time.sleep()') 1372 1373 mock_sleep.side_effect = mock_sleep_side_effect 1374 1375 # Run the mocked-out interaction with the hook. 1376 self.evaluate(variables_lib.global_variables_initializer()) 1377 run_context = session_run_hook.SessionRunContext( 1378 original_args=None, session=sess) 1379 hook.before_run(run_context) 1380 self.assertEqual(Context.counter, 2) 1381 1382 1383class FinalOpsHookTest(test.TestCase): 1384 1385 def test_final_ops_is_scalar_tensor(self): 1386 with ops.Graph().as_default(): 1387 expected_value = 4 1388 final_ops = constant_op.constant(expected_value) 1389 1390 hook = basic_session_run_hooks.FinalOpsHook(final_ops) 1391 hook.begin() 1392 1393 with session_lib.Session() as session: 1394 hook.end(session) 1395 self.assertEqual(expected_value, 1396 hook.final_ops_values) 1397 1398 def test_final_ops_is_tensor(self): 1399 with ops.Graph().as_default(): 1400 expected_values = [1, 6, 3, 5, 2, 4] 1401 final_ops = constant_op.constant(expected_values) 1402 1403 hook = basic_session_run_hooks.FinalOpsHook(final_ops) 1404 hook.begin() 1405 1406 with session_lib.Session() as session: 1407 hook.end(session) 1408 self.assertListEqual(expected_values, 1409 hook.final_ops_values.tolist()) 1410 1411 def test_final_ops_triggers_out_of_range_error(self): 1412 with ops.Graph().as_default(): 1413 dataset = dataset_ops.Dataset.range(1) 1414 iterator = dataset_ops.make_one_shot_iterator(dataset) 1415 read_ops = iterator.get_next() 1416 final_ops = read_ops 1417 1418 hook = basic_session_run_hooks.FinalOpsHook(final_ops) 1419 hook.begin() 1420 1421 with session_lib.Session() as session: 1422 session.run(read_ops) 1423 with test.mock.patch.object(tf_logging, 'warning') as mock_log: 1424 with self.assertRaisesRegex(errors.OutOfRangeError, 1425 'End of sequence'): 1426 hook.end(session) 1427 self.assertRegex( 1428 str(mock_log.call_args), 'dependency back to some input source') 1429 1430 def test_final_ops_with_dictionary(self): 1431 with ops.Graph().as_default(): 1432 expected_values = [4, -3] 1433 final_ops = array_ops.placeholder(dtype=dtypes.float32) 1434 final_ops_feed_dict = {final_ops: expected_values} 1435 1436 hook = basic_session_run_hooks.FinalOpsHook( 1437 final_ops, final_ops_feed_dict) 1438 hook.begin() 1439 1440 with session_lib.Session() as session: 1441 hook.end(session) 1442 self.assertListEqual(expected_values, 1443 hook.final_ops_values.tolist()) 1444 1445 1446@test_util.run_deprecated_v1 1447class ResourceSummarySaverHookTest(test.TestCase): 1448 1449 def setUp(self): 1450 test.TestCase.setUp(self) 1451 1452 self.log_dir = 'log/dir' 1453 self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir) 1454 1455 var = variable_scope.get_variable('var', initializer=0.0, use_resource=True) 1456 tensor = state_ops.assign_add(var, 1.0) 1457 self.summary_op = summary_lib.scalar('my_summary', tensor) 1458 1459 with variable_scope.variable_scope('foo', use_resource=True): 1460 training_util.create_global_step() 1461 self.train_op = training_util._increment_global_step(1) 1462 1463 def test_save_steps(self): 1464 hook = basic_session_run_hooks.SummarySaverHook( 1465 save_steps=8, 1466 summary_writer=self.summary_writer, 1467 summary_op=self.summary_op) 1468 1469 with self.cached_session() as sess: 1470 hook.begin() 1471 self.evaluate(variables_lib.global_variables_initializer()) 1472 mon_sess = monitored_session._HookedSession(sess, [hook]) 1473 for _ in range(30): 1474 mon_sess.run(self.train_op) 1475 hook.end(sess) 1476 1477 self.summary_writer.assert_summaries( 1478 test_case=self, 1479 expected_logdir=self.log_dir, 1480 expected_summaries={ 1481 1: { 1482 'my_summary': 1.0 1483 }, 1484 9: { 1485 'my_summary': 2.0 1486 }, 1487 17: { 1488 'my_summary': 3.0 1489 }, 1490 25: { 1491 'my_summary': 4.0 1492 }, 1493 }) 1494 1495 1496class FeedFnHookTest(test.TestCase): 1497 1498 def test_feeding_placeholder(self): 1499 with ops.Graph().as_default(), session_lib.Session() as sess: 1500 x = array_ops.placeholder(dtype=dtypes.float32) 1501 y = x + 1 1502 hook = basic_session_run_hooks.FeedFnHook( 1503 feed_fn=lambda: {x: 1.0}) 1504 hook.begin() 1505 mon_sess = monitored_session._HookedSession(sess, [hook]) 1506 self.assertEqual(mon_sess.run(y), 2) 1507 1508 1509class ProfilerHookTest(test.TestCase): 1510 1511 def setUp(self): 1512 super(ProfilerHookTest, self).setUp() 1513 self.output_dir = tempfile.mkdtemp() 1514 self.graph = ops.Graph() 1515 self.filepattern = os.path.join(self.output_dir, 'timeline-*.json') 1516 with self.graph.as_default(): 1517 self.global_step = training_util.get_or_create_global_step() 1518 self.train_op = state_ops.assign_add(self.global_step, 1) 1519 1520 def tearDown(self): 1521 super(ProfilerHookTest, self).tearDown() 1522 shutil.rmtree(self.output_dir, ignore_errors=True) 1523 1524 def _count_timeline_files(self): 1525 return len(gfile.Glob(self.filepattern)) 1526 1527 @test_util.run_deprecated_v1 1528 def test_raise_in_both_secs_and_steps(self): 1529 with self.assertRaises(ValueError): 1530 basic_session_run_hooks.ProfilerHook(save_secs=10, save_steps=20) 1531 1532 @test_util.run_deprecated_v1 1533 def test_raise_in_none_secs_and_steps(self): 1534 with self.assertRaises(ValueError): 1535 basic_session_run_hooks.ProfilerHook(save_secs=None, save_steps=None) 1536 1537 def test_save_secs_does_not_save_in_first_step(self): 1538 with self.graph.as_default(): 1539 hook = basic_session_run_hooks.ProfilerHook( 1540 save_secs=2, output_dir=self.output_dir) 1541 with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: 1542 sess.run(self.train_op) 1543 self.assertEqual(0, self._count_timeline_files()) 1544 1545 @test.mock.patch.object(time, 'time') 1546 def test_save_secs_saves_periodically(self, mock_time): 1547 # Pick a fixed start time. 1548 with self.graph.as_default(): 1549 mock_time.return_value = MOCK_START_TIME 1550 hook = basic_session_run_hooks.ProfilerHook( 1551 save_secs=2, output_dir=self.output_dir) 1552 with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: 1553 sess.run(self.train_op) # Not saved. 1554 self.assertEqual(0, self._count_timeline_files()) 1555 # Simulate 2.5 seconds of sleep. 1556 mock_time.return_value = MOCK_START_TIME + 2.5 1557 sess.run(self.train_op) # Saved. 1558 self.assertEqual(1, self._count_timeline_files()) 1559 1560 # Pretend some small amount of time has passed. 1561 mock_time.return_value = MOCK_START_TIME + 2.6 1562 sess.run(self.train_op) # Not saved. 1563 # Edge test just before we should save the timeline. 1564 mock_time.return_value = MOCK_START_TIME + 4.4 1565 sess.run(self.train_op) # Not saved. 1566 self.assertEqual(1, self._count_timeline_files()) 1567 1568 mock_time.return_value = MOCK_START_TIME + 4.5 1569 sess.run(self.train_op) # Saved. 1570 self.assertEqual(2, self._count_timeline_files()) 1571 1572 def test_save_steps_does_not_save_in_first_step(self): 1573 with self.graph.as_default(): 1574 hook = basic_session_run_hooks.ProfilerHook( 1575 save_steps=1, output_dir=self.output_dir) 1576 with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: 1577 sess.run(self.train_op) # Not saved. 1578 self.assertEqual(0, self._count_timeline_files()) 1579 1580 def test_save_steps_saves_periodically(self): 1581 with self.graph.as_default(): 1582 hook = basic_session_run_hooks.ProfilerHook( 1583 save_steps=2, output_dir=self.output_dir) 1584 with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: 1585 self.assertEqual(0, self._count_timeline_files()) 1586 sess.run(self.train_op) # Not saved. 1587 self.assertEqual(0, self._count_timeline_files()) 1588 sess.run(self.train_op) # Saved. 1589 self.assertEqual(1, self._count_timeline_files()) 1590 sess.run(self.train_op) # Not saved. 1591 self.assertEqual(1, self._count_timeline_files()) 1592 sess.run(self.train_op) # Saved. 1593 self.assertEqual(2, self._count_timeline_files()) 1594 sess.run(self.train_op) # Not saved. 1595 self.assertEqual(2, self._count_timeline_files()) 1596 1597 def test_run_metadata_saves(self): 1598 writer_cache.FileWriterCache.clear() 1599 fake_summary_writer.FakeSummaryWriter.install() 1600 fake_writer = writer_cache.FileWriterCache.get(self.output_dir) 1601 with self.graph.as_default(): 1602 hook = basic_session_run_hooks.ProfilerHook( 1603 save_steps=1, output_dir=self.output_dir) 1604 with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: 1605 sess.run(self.train_op) # Not saved. 1606 sess.run(self.train_op) # Saved. 1607 self.assertEqual( 1608 list(fake_writer._added_run_metadata.keys()), ['step_2']) 1609 fake_summary_writer.FakeSummaryWriter.uninstall() 1610 1611 1612if __name__ == '__main__': 1613 test.main() 1614