• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# pylint: disable=g-bad-file-header
2# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Tests for basic_session_run_hooks."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import os.path
23import shutil
24import tempfile
25import time
26
27from tensorflow.contrib.framework.python.framework import checkpoint_utils
28from tensorflow.contrib.framework.python.ops import variables
29from tensorflow.contrib.testing.python.framework import fake_summary_writer
30from tensorflow.python.client import session as session_lib
31from tensorflow.python.data.ops import dataset_ops
32from tensorflow.python.framework import constant_op
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import errors
35from tensorflow.python.framework import meta_graph
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import test_util
38from tensorflow.python.ops import array_ops
39from tensorflow.python.ops import control_flow_ops
40from tensorflow.python.ops import state_ops
41from tensorflow.python.ops import variable_scope
42from tensorflow.python.ops import variables as variables_lib
43import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
44from tensorflow.python.platform import gfile
45from tensorflow.python.platform import test
46from tensorflow.python.platform import tf_logging
47from tensorflow.python.summary import summary as summary_lib
48from tensorflow.python.summary.writer import writer_cache
49from tensorflow.python.training import basic_session_run_hooks
50from tensorflow.python.training import monitored_session
51from tensorflow.python.training import session_run_hook
52from tensorflow.python.training import training_util
53
54
# Fixed epoch timestamp (roughly 2017-01-17 23:33 UTC) that tests assign to a
# mocked time.time() so time-based hooks see a realistic, deterministic clock.
MOCK_START_TIME = 1484695987.209386
58
59
class MockCheckpointSaverListener(
    basic_session_run_hooks.CheckpointSaverListener):
  """CheckpointSaverListener that counts how often each callback fires.

  Setting `ask_for_stop` to True makes `after_save` return True, which the
  tests use to ask the owning CheckpointSaverHook to stop training.
  """

  def __init__(self):
    # One counter per listener callback; summarized by get_counts().
    self.begin_count = self.before_save_count = 0
    self.after_save_count = self.end_count = 0
    # When True, after_save() returns True to request a stop.
    self.ask_for_stop = False

  def begin(self):
    """Records a call to begin()."""
    self.begin_count += 1

  def before_save(self, session, global_step):
    """Records a call to before_save(); the arguments are ignored."""
    self.before_save_count += 1

  def after_save(self, session, global_step):
    """Records a call to after_save().

    Returns:
      True when `ask_for_stop` is set, otherwise None.
    """
    self.after_save_count += 1
    if not self.ask_for_stop:
      return None
    return True

  def end(self, session, global_step):
    """Records a call to end(); the arguments are ignored."""
    self.end_count += 1

  def get_counts(self):
    """Returns a dict mapping each callback name to its call count."""
    return dict(
        begin=self.begin_count,
        before_save=self.before_save_count,
        after_save=self.after_save_count,
        end=self.end_count)
91
92
class SecondOrStepTimerTest(test.TestCase):
  """Tests for basic_session_run_hooks.SecondOrStepTimer."""

  @test_util.run_deprecated_v1
  def test_raise_in_both_secs_and_steps(self):
    with self.assertRaises(ValueError):
      basic_session_run_hooks.SecondOrStepTimer(every_secs=2.0, every_steps=10)

  @test_util.run_deprecated_v1
  def test_raise_in_none_secs_and_steps(self):
    with self.assertRaises(ValueError):
      basic_session_run_hooks.SecondOrStepTimer()

  @test.mock.patch.object(time, 'time')
  def test_every_secs(self, mock_time):
    mock_time.return_value = MOCK_START_TIME
    timer = basic_session_run_hooks.SecondOrStepTimer(every_secs=1.0)
    self.assertTrue(timer.should_trigger_for_step(1))

    timer.update_last_triggered_step(1)
    self.assertFalse(timer.should_trigger_for_step(1))
    self.assertFalse(timer.should_trigger_for_step(2))

    mock_time.return_value += 1.0
    self.assertFalse(timer.should_trigger_for_step(1))
    self.assertTrue(timer.should_trigger_for_step(2))

  def test_every_steps(self):
    timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=3)
    self.assertTrue(timer.should_trigger_for_step(1))

    timer.update_last_triggered_step(1)
    self.assertFalse(timer.should_trigger_for_step(1))
    self.assertFalse(timer.should_trigger_for_step(2))
    self.assertFalse(timer.should_trigger_for_step(3))
    self.assertTrue(timer.should_trigger_for_step(4))

  @test.mock.patch.object(time, 'time')
  def test_update_last_triggered_step(self, mock_time):
    # The clock is mocked, as in test_every_secs: with the real clock two
    # back-to-back time.time() calls may return the same value on
    # coarse-resolution clocks, which would make a strict "elapsed > 0"
    # check flaky. Mocking also lets the exact elapsed time be asserted.
    mock_time.return_value = MOCK_START_TIME
    timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=1)

    # The first trigger has no earlier trigger to measure against.
    elapsed_secs, elapsed_steps = timer.update_last_triggered_step(1)
    self.assertIsNone(elapsed_secs)
    self.assertIsNone(elapsed_steps)

    mock_time.return_value = MOCK_START_TIME + 2.0
    elapsed_secs, elapsed_steps = timer.update_last_triggered_step(5)
    self.assertAlmostEqual(2.0, elapsed_secs)
    self.assertEqual(4, elapsed_steps)

    mock_time.return_value = MOCK_START_TIME + 3.5
    elapsed_secs, elapsed_steps = timer.update_last_triggered_step(7)
    self.assertAlmostEqual(1.5, elapsed_secs)
    self.assertEqual(2, elapsed_steps)
143
144
class StopAtStepTest(test.TestCase):
  """Tests for basic_session_run_hooks.StopAtStepHook."""

  def _check_stop_sequence(self, hook, schedule):
    """Drives `hook` through a _HookedSession and checks should_stop().

    Args:
      hook: A freshly constructed StopAtStepHook. Its `begin()` is called
        here, inside a new default graph.
      schedule: Sequence of `(global_step_value, expect_stop)` pairs. For each
        pair the global step is assigned, one no-op is run through the hooked
        session, and `should_stop()` is checked against `expect_stop`. The
        first pair's value is assigned before `after_create_session`. After a
        run that reported a stop, the session's stop flag is cleared before
        the next run so the hook must request the stop again.
    """
    with ops.Graph().as_default():
      global_step = variables.get_or_create_global_step()
      no_op = control_flow_ops.no_op()
      hook.begin()
      with session_lib.Session() as sess:
        mon_sess = monitored_session._HookedSession(sess, [hook])
        stopped = False
        for index, (step_value, expect_stop) in enumerate(schedule):
          sess.run(state_ops.assign(global_step, step_value))
          if index == 0:
            hook.after_create_session(sess, None)
          elif stopped:
            mon_sess._should_stop = False
          mon_sess.run(no_op)
          stopped = mon_sess.should_stop()
          if expect_stop:
            self.assertTrue(stopped)
          else:
            self.assertFalse(stopped)

  def test_raise_in_both_last_step_and_num_steps(self):
    with self.assertRaises(ValueError):
      basic_session_run_hooks.StopAtStepHook(num_steps=10, last_step=20)

  def test_stop_based_on_last_step(self):
    hook = basic_session_run_hooks.StopAtStepHook(last_step=10)
    self._check_stop_sequence(
        hook, [(5, False), (9, False), (10, True), (11, True)])

  def test_stop_based_on_num_step(self):
    # The global step is 5 when the session is created, so with num_steps=10
    # the expected stop point is step 15.
    hook = basic_session_run_hooks.StopAtStepHook(num_steps=10)
    self._check_stop_sequence(
        hook, [(5, False), (13, False), (14, False), (15, True), (16, True)])

  def test_stop_based_with_multiple_steps(self):
    # Jumping straight past the stop point must still trigger a stop.
    hook = basic_session_run_hooks.StopAtStepHook(num_steps=10)
    self._check_stop_sequence(hook, [(5, False), (15, True)])
217
218
class LoggingTensorHookTest(test.TestCase):
  """Tests for basic_session_run_hooks.LoggingTensorHook."""

  def setUp(self):
    # Replace tf_logging.info with a wrapper that records the positional
    # arguments of the latest call and still forwards them to the real
    # logger, so tests can inspect what the hook logged. tearDown restores
    # the real function.
    self._actual_log = tf_logging.info
    self.logged_message = None

    def recording_info(*args, **kwargs):
      self.logged_message = args
      self._actual_log(*args, **kwargs)

    tf_logging.info = recording_info

  def tearDown(self):
    tf_logging.info = self._actual_log

  def _assert_logged(self, pattern):
    """Asserts the latest captured log arguments match regex `pattern`."""
    self.assertRegexpMatches(str(self.logged_message), pattern)

  def _assert_not_logged(self, text):
    """Asserts `text` does not occur in the latest captured log arguments.

    Written as a plain substring search rather than a negated-regex
    assertion, presumably for portability across unittest versions.
    """
    self.assertEqual(str(self.logged_message).find(text), -1)

  def test_illegal_args(self):
    cases = [
        ('nvalid every_n_iter', dict(every_n_iter=0)),
        ('nvalid every_n_iter', dict(every_n_iter=-10)),
        ('xactly one of', dict(every_n_iter=5, every_n_secs=5)),
        ('xactly one of', {}),
    ]
    for expected_regexp, kwargs in cases:
      with self.assertRaisesRegexp(ValueError, expected_regexp):
        basic_session_run_hooks.LoggingTensorHook(tensors=['t'], **kwargs)

  def test_print_at_end_only(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      tensor = constant_op.constant(42.0, name='foo')
      train_op = constant_op.constant(3)
      hook = basic_session_run_hooks.LoggingTensorHook(
          tensors=[tensor.name], at_end=True)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      self.evaluate(variables_lib.global_variables_initializer())
      self.logged_message = ''
      for _ in range(3):
        mon_sess.run(train_op)
        self._assert_not_logged(tensor.name)

      hook.end(sess)
      self._assert_logged(tensor.name)

  def _validate_print_every_n_steps(self, sess, at_end):
    """Checks every_n_iter=10 logging, and end-of-run logging iff `at_end`."""
    tensor = constant_op.constant(42.0, name='foo')
    train_op = constant_op.constant(3)
    hook = basic_session_run_hooks.LoggingTensorHook(
        tensors=[tensor.name], every_n_iter=10, at_end=at_end)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])
    self.evaluate(variables_lib.global_variables_initializer())

    # The very first run logs.
    mon_sess.run(train_op)
    self._assert_logged(tensor.name)

    # Then, three times over: nine silent runs followed by a logging run.
    for _ in range(3):
      self.logged_message = ''
      for _ in range(9):
        mon_sess.run(train_op)
        self._assert_not_logged(tensor.name)
      mon_sess.run(train_op)
      self._assert_logged(tensor.name)

    # One more run to verify the counter was reset after the last log.
    self.logged_message = ''
    mon_sess.run(train_op)
    self._assert_not_logged(tensor.name)

    self.logged_message = ''
    hook.end(sess)
    if at_end:
      self._assert_logged(tensor.name)
    else:
      self._assert_not_logged(tensor.name)

  def test_print_every_n_steps(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      # The second pass verifies that a new hook starts from a clean state.
      for _ in range(2):
        self._validate_print_every_n_steps(sess, at_end=False)

  def test_print_every_n_steps_and_end(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      # The second pass verifies that a new hook starts from a clean state.
      for _ in range(2):
        self._validate_print_every_n_steps(sess, at_end=True)

  def test_print_first_step(self):
    # With every_n_iter=1 the first iteration has no earlier trigger, so its
    # log line must not contain an elapsed-time ("sec") figure.
    with ops.Graph().as_default(), session_lib.Session() as sess:
      tensor = constant_op.constant(42.0, name='foo')
      train_op = constant_op.constant(3)
      hook = basic_session_run_hooks.LoggingTensorHook(
          tensors={'foo': tensor}, every_n_iter=1)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      self.evaluate(variables_lib.global_variables_initializer())
      mon_sess.run(train_op)
      self._assert_logged('foo')
      self._assert_not_logged('sec')

  def _validate_print_every_n_secs(self, sess, at_end, mock_time):
    """Checks every_n_secs=1.0 logging, and end-of-run logging iff `at_end`.

    `mock_time` is the patched time.time; it is advanced by one second here.
    """
    tensor = constant_op.constant(42.0, name='foo')
    train_op = constant_op.constant(3)

    hook = basic_session_run_hooks.LoggingTensorHook(
        tensors=[tensor.name], every_n_secs=1.0, at_end=at_end)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])
    self.evaluate(variables_lib.global_variables_initializer())

    # First run logs.
    mon_sess.run(train_op)
    self._assert_logged(tensor.name)

    # A second run without the clock moving stays silent.
    self.logged_message = ''
    mon_sess.run(train_op)
    self._assert_not_logged(tensor.name)
    mock_time.return_value += 1.0

    # Once a second has passed, the next run logs again.
    self.logged_message = ''
    mon_sess.run(train_op)
    self._assert_logged(tensor.name)

    self.logged_message = ''
    hook.end(sess)
    if at_end:
      self._assert_logged(tensor.name)
    else:
      self._assert_not_logged(tensor.name)

  @test.mock.patch.object(time, 'time')
  def test_print_every_n_secs(self, mock_time):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      mock_time.return_value = MOCK_START_TIME
      # The second pass verifies that a new hook starts from a clean state.
      for _ in range(2):
        self._validate_print_every_n_secs(
            sess, at_end=False, mock_time=mock_time)

  @test.mock.patch.object(time, 'time')
  def test_print_every_n_secs_and_end(self, mock_time):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      mock_time.return_value = MOCK_START_TIME
      # The second pass verifies that a new hook starts from a clean state.
      for _ in range(2):
        self._validate_print_every_n_secs(
            sess, at_end=True, mock_time=mock_time)

  def test_print_formatter(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      tensor = constant_op.constant(42.0, name='foo')
      train_op = constant_op.constant(3)
      hook = basic_session_run_hooks.LoggingTensorHook(
          tensors=[tensor.name], every_n_iter=10,
          formatter=lambda items: 'qqq=%s' % items[tensor.name])
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      self.evaluate(variables_lib.global_variables_initializer())
      mon_sess.run(train_op)
      self.assertEqual(self.logged_message[0], 'qqq=42.0')
385
386
class CheckpointSaverHookTest(test.TestCase):
  """Tests for basic_session_run_hooks.CheckpointSaverHook."""

  def setUp(self):
    self.model_dir = tempfile.mkdtemp()
    self.graph = ops.Graph()
    with self.graph.as_default():
      self.scaffold = monitored_session.Scaffold()
      self.global_step = variables.get_or_create_global_step()
      self.train_op = training_util._increment_global_step(1)

  def tearDown(self):
    shutil.rmtree(self.model_dir, ignore_errors=True)

  def _saved_global_step(self):
    """Returns `self.global_step` as loaded from `self.model_dir`."""
    return checkpoint_utils.load_variable(
        self.model_dir, self.global_step.name)

  def _restore_global_step(self):
    """Restores the global step from `self.model_dir` in a fresh graph."""
    with ops.Graph().as_default():
      global_step = variables.get_or_create_global_step()
      with monitored_session.SingularMonitoredSession(
          checkpoint_dir=self.model_dir) as sess:
        return sess.run(global_step)

  def test_saves_when_saver_and_scaffold_both_missing(self):
    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_steps=1)
      hook.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)
        self.assertEqual(1, self._saved_global_step())

  def test_raise_when_saver_and_scaffold_both_present(self):
    # Finalizing the scaffold is what gives it a saver. The scaffold was
    # built without one, so before finalize() `scaffold.saver` is presumably
    # None and the hook would never see two savers. The assertIsNotNone
    # guards that premise.
    with self.graph.as_default():
      self.scaffold.finalize()
    self.assertIsNotNone(self.scaffold.saver)
    # save_steps is given so the saver/scaffold conflict is the only invalid
    # argument. Without it, the ValueError for missing save_secs/save_steps
    # would satisfy the assertion on its own.
    with self.assertRaisesRegexp(ValueError, 'saver and scaffold'):
      basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir,
          save_steps=1,
          saver=self.scaffold.saver,
          scaffold=self.scaffold)

  @test_util.run_deprecated_v1
  def test_raise_in_both_secs_and_steps(self):
    with self.assertRaises(ValueError):
      basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_secs=10, save_steps=20)

  @test_util.run_deprecated_v1
  def test_raise_in_none_secs_and_steps(self):
    with self.assertRaises(ValueError):
      basic_session_run_hooks.CheckpointSaverHook(self.model_dir)

  def test_save_secs_saves_in_first_step(self):
    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_secs=2, scaffold=self.scaffold)
      hook.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)
        self.assertEqual(1, self._saved_global_step())

  def test_save_secs_calls_listeners_at_begin_and_end(self):
    with self.graph.as_default():
      listener = MockCheckpointSaverListener()
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir,
          save_secs=2,
          scaffold=self.scaffold,
          listeners=[listener])
      hook.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)  # hook runs here
        mon_sess.run(self.train_op)  # hook won't run here, so it does at end
        hook.end(sess)  # hook runs here
      self.assertEqual({
          'begin': 1,
          'before_save': 2,
          'after_save': 2,
          'end': 1
      }, listener.get_counts())

  def test_listener_with_monitored_session(self):
    with ops.Graph().as_default():
      scaffold = monitored_session.Scaffold()
      global_step = variables.get_or_create_global_step()
      train_op = training_util._increment_global_step(1)
      listener = MockCheckpointSaverListener()
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir,
          save_steps=1,
          scaffold=scaffold,
          listeners=[listener])
      with monitored_session.SingularMonitoredSession(
          hooks=[hook],
          scaffold=scaffold,
          checkpoint_dir=self.model_dir) as sess:
        sess.run(train_op)
        sess.run(train_op)
        global_step_val = sess.raw_session().run(global_step)
      listener_counts = listener.get_counts()
    self.assertEqual(2, global_step_val)
    self.assertEqual({
        'begin': 1,
        'before_save': 3,
        'after_save': 3,
        'end': 1
    }, listener_counts)

  def test_listener_stops_training_in_after_save(self):
    with ops.Graph().as_default():
      scaffold = monitored_session.Scaffold()
      variables.get_or_create_global_step()
      train_op = training_util._increment_global_step(1)
      listener = MockCheckpointSaverListener()
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_steps=1, scaffold=scaffold, listeners=[listener])
      with monitored_session.SingularMonitoredSession(
          hooks=[hook], scaffold=scaffold,
          checkpoint_dir=self.model_dir) as sess:
        sess.run(train_op)
        self.assertFalse(sess.should_stop())
        sess.run(train_op)
        self.assertFalse(sess.should_stop())
        # From now on after_save() returns True, which must stop training.
        listener.ask_for_stop = True
        sess.run(train_op)
        self.assertTrue(sess.should_stop())

  def test_listener_with_default_saver(self):
    with ops.Graph().as_default():
      global_step = variables.get_or_create_global_step()
      train_op = training_util._increment_global_step(1)
      listener = MockCheckpointSaverListener()
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir,
          save_steps=1,
          listeners=[listener])
      with monitored_session.SingularMonitoredSession(
          hooks=[hook],
          checkpoint_dir=self.model_dir) as sess:
        sess.run(train_op)
        sess.run(train_op)
        global_step_val = sess.raw_session().run(global_step)
      listener_counts = listener.get_counts()
    self.assertEqual(2, global_step_val)
    self.assertEqual({
        'begin': 1,
        'before_save': 3,
        'after_save': 3,
        'end': 1
    }, listener_counts)
    # The checkpoint written through the default saver must hold step 2.
    self.assertEqual(2, self._restore_global_step())

  def test_two_listeners_with_default_saver(self):
    with ops.Graph().as_default():
      global_step = variables.get_or_create_global_step()
      train_op = training_util._increment_global_step(1)
      listener1 = MockCheckpointSaverListener()
      listener2 = MockCheckpointSaverListener()
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir,
          save_steps=1,
          listeners=[listener1, listener2])
      with monitored_session.SingularMonitoredSession(
          hooks=[hook],
          checkpoint_dir=self.model_dir) as sess:
        sess.run(train_op)
        sess.run(train_op)
        global_step_val = sess.raw_session().run(global_step)
      listener1_counts = listener1.get_counts()
      listener2_counts = listener2.get_counts()
    self.assertEqual(2, global_step_val)
    self.assertEqual({
        'begin': 1,
        'before_save': 3,
        'after_save': 3,
        'end': 1
    }, listener1_counts)
    self.assertEqual(listener1_counts, listener2_counts)
    # The checkpoint written through the default saver must hold step 2.
    self.assertEqual(2, self._restore_global_step())

  @test.mock.patch.object(time, 'time')
  def test_save_secs_saves_periodically(self, mock_time):
    with self.graph.as_default():
      mock_time.return_value = MOCK_START_TIME
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_secs=2, scaffold=self.scaffold)
      hook.begin()
      self.scaffold.finalize()

      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])

        mock_time.return_value = MOCK_START_TIME
        mon_sess.run(self.train_op)  # Saved.

        mock_time.return_value = MOCK_START_TIME + 0.5
        mon_sess.run(self.train_op)  # Not saved.

        self.assertEqual(1, self._saved_global_step())

        # Simulate 2.5 seconds of sleep.
        mock_time.return_value = MOCK_START_TIME + 2.5
        mon_sess.run(self.train_op)  # Saved.

        mock_time.return_value = MOCK_START_TIME + 2.6
        mon_sess.run(self.train_op)  # Not saved.

        mock_time.return_value = MOCK_START_TIME + 2.7
        mon_sess.run(self.train_op)  # Not saved.

        self.assertEqual(3, self._saved_global_step())

        # Simulate 7.5 more seconds of sleep (10 seconds from start).
        mock_time.return_value = MOCK_START_TIME + 10
        mon_sess.run(self.train_op)  # Saved.
        self.assertEqual(6, self._saved_global_step())

  @test.mock.patch.object(time, 'time')
  def test_save_secs_calls_listeners_periodically(self, mock_time):
    with self.graph.as_default():
      mock_time.return_value = MOCK_START_TIME
      listener = MockCheckpointSaverListener()
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir,
          save_secs=2,
          scaffold=self.scaffold,
          listeners=[listener])
      hook.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])

        mock_time.return_value = MOCK_START_TIME + 0.5
        mon_sess.run(self.train_op)  # hook runs here

        mock_time.return_value = MOCK_START_TIME + 0.5
        mon_sess.run(self.train_op)

        mock_time.return_value = MOCK_START_TIME + 3.0
        mon_sess.run(self.train_op)  # hook runs here

        mock_time.return_value = MOCK_START_TIME + 3.5
        mon_sess.run(self.train_op)

        mock_time.return_value = MOCK_START_TIME + 4.0
        mon_sess.run(self.train_op)

        mock_time.return_value = MOCK_START_TIME + 6.5
        mon_sess.run(self.train_op)  # hook runs here

        mock_time.return_value = MOCK_START_TIME + 7.0
        mon_sess.run(self.train_op)  # hook won't run here, so it does at end

        mock_time.return_value = MOCK_START_TIME + 7.5
        hook.end(sess)  # hook runs here
      self.assertEqual({
          'begin': 1,
          'before_save': 4,
          'after_save': 4,
          'end': 1
      }, listener.get_counts())

  def test_save_steps_saves_in_first_step(self):
    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_steps=2, scaffold=self.scaffold)
      hook.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)
        self.assertEqual(1, self._saved_global_step())

  def test_save_steps_saves_periodically(self):
    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_steps=2, scaffold=self.scaffold)
      hook.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)
        mon_sess.run(self.train_op)
        # Not saved
        self.assertEqual(1, self._saved_global_step())
        mon_sess.run(self.train_op)
        # saved
        self.assertEqual(3, self._saved_global_step())
        mon_sess.run(self.train_op)
        # Not saved
        self.assertEqual(3, self._saved_global_step())
        mon_sess.run(self.train_op)
        # saved
        self.assertEqual(5, self._saved_global_step())

  def test_save_saves_at_end(self):
    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_secs=2, scaffold=self.scaffold)
      hook.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)
        mon_sess.run(self.train_op)
        hook.end(sess)
        self.assertEqual(2, self._saved_global_step())

  def test_summary_writer_defs(self):
    fake_summary_writer.FakeSummaryWriter.install()
    # Registered as a cleanup so the fake writer is removed even when an
    # assertion below fails; otherwise it would leak into later tests.
    self.addCleanup(fake_summary_writer.FakeSummaryWriter.uninstall)
    writer_cache.FileWriterCache.clear()
    summary_writer = writer_cache.FileWriterCache.get(self.model_dir)

    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_steps=2, scaffold=self.scaffold)
      hook.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        hook.after_create_session(sess, None)
        mon_sess.run(self.train_op)
      summary_writer.assert_summaries(
          test_case=self,
          expected_logdir=self.model_dir,
          expected_added_meta_graphs=[
              meta_graph.create_meta_graph_def(
                  graph_def=self.graph.as_graph_def(add_shapes=True),
                  saver_def=self.scaffold.saver.saver_def)
          ])

  def test_save_checkpoint_before_first_train_step(self):
    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_steps=2, scaffold=self.scaffold)
      hook.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        mon_sess = monitored_session._HookedSession(sess, [hook])
        sess.run(self.scaffold.init_op)
        hook.after_create_session(sess, None)
        # Verifies that checkpoint is saved at step 0.
        self.assertEqual(0, self._saved_global_step())
        # Verifies that no checkpoint is saved after one training step.
        mon_sess.run(self.train_op)
        self.assertEqual(0, self._saved_global_step())
        # Verifies that checkpoint is saved after save_steps.
        mon_sess.run(self.train_op)
        self.assertEqual(2, self._saved_global_step())
779
780
class CheckpointSaverHookMultiStepTest(test.TestCase):
  """Tests CheckpointSaverHook when a single run advances several steps.

  Every run of `train_op` increments the global step by `steps_per_run`, and
  the hook is informed of that through `_set_steps_per_run`.
  """

  def setUp(self):
    # Chain to the base class so its per-test setup is not skipped, as
    # ProfilerHookTest in this file does.
    super(CheckpointSaverHookMultiStepTest, self).setUp()
    self.model_dir = tempfile.mkdtemp()
    self.graph = ops.Graph()
    self.steps_per_run = 5
    with self.graph.as_default():
      self.scaffold = monitored_session.Scaffold()
      self.global_step = variables.get_or_create_global_step()
      self.train_op = training_util._increment_global_step(self.steps_per_run)

  def tearDown(self):
    shutil.rmtree(self.model_dir, ignore_errors=True)
    super(CheckpointSaverHookMultiStepTest, self).tearDown()

  def _create_hook(self):
    """Returns a hook saving every 2 * steps_per_run global steps."""
    hook = basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir,
        save_steps=2 * self.steps_per_run,
        scaffold=self.scaffold)
    hook._set_steps_per_run(self.steps_per_run)
    return hook

  def _saved_global_step(self):
    """Returns the global step value stored in the latest checkpoint."""
    return checkpoint_utils.load_variable(self.model_dir,
                                          self.global_step.name)

  def test_save_steps_saves_in_first_step(self):
    with self.graph.as_default():
      hook = self._create_hook()
      hook.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)
        self.assertEqual(5, self._saved_global_step())

  def test_save_steps_saves_periodically(self):
    with self.graph.as_default():
      hook = self._create_hook()
      hook.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        # Checkpointed step after each run. The global step after the runs is
        # 5, 10, 15, 20, 25; only the runs ending at 5, 15 and 25 save.
        for expected_saved_step in [5, 5, 15, 15, 25]:
          mon_sess.run(self.train_op)
          self.assertEqual(expected_saved_step, self._saved_global_step())

  def test_save_steps_saves_at_end(self):
    with self.graph.as_default():
      hook = self._create_hook()
      hook.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)
        mon_sess.run(self.train_op)
        # The run ending at step 10 does not save on its own; end() must
        # write the final checkpoint.
        hook.end(sess)
        self.assertEqual(10, self._saved_global_step())
872
873
class ResourceCheckpointSaverHookTest(test.TestCase):
  """CheckpointSaverHook tests where the global step is a resource variable."""

  def setUp(self):
    # Chain to the base class so its per-test setup is not skipped.
    super(ResourceCheckpointSaverHookTest, self).setUp()
    self.model_dir = tempfile.mkdtemp()
    self.graph = ops.Graph()
    with self.graph.as_default():
      self.scaffold = monitored_session.Scaffold()
      with variable_scope.variable_scope('foo', use_resource=True):
        self.global_step = training_util.get_or_create_global_step()
      self.train_op = training_util._increment_global_step(1)

  def tearDown(self):
    # setUp creates a fresh temporary directory for every test; remove it so
    # test runs do not leave checkpoints behind (previously it was leaked).
    shutil.rmtree(self.model_dir, ignore_errors=True)
    super(ResourceCheckpointSaverHookTest, self).tearDown()

  def _saved_global_step(self):
    """Returns the global step value stored in the latest checkpoint."""
    return checkpoint_utils.load_variable(self.model_dir,
                                          self.global_step.name)

  def test_save_steps_saves_periodically(self):
    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_steps=2, scaffold=self.scaffold)
      hook.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)
        # Checkpointed step after each further run (global steps 2..5): runs
        # ending at steps 2 and 4 do not save, those ending at 3 and 5 do.
        for expected_saved_step in [1, 3, 3, 5]:
          mon_sess.run(self.train_op)
          self.assertEqual(expected_saved_step, self._saved_global_step())
915
916
class StepCounterHookTest(test.TestCase):
  """Tests for StepCounterHook's step-rate summaries and stall warnings."""

  def setUp(self):
    # Chain to the base class so its per-test setup is not skipped, as
    # ProfilerHookTest in this file does.
    super(StepCounterHookTest, self).setUp()
    self.log_dir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.log_dir, ignore_errors=True)
    super(StepCounterHookTest, self).tearDown()

  def _assert_step_rate_summaries(self, summary_writer, graph, expected_steps):
    """Asserts exactly `expected_steps` carry a positive steps/sec summary.

    Args:
      summary_writer: `FakeSummaryWriter` the hook under test wrote to.
      graph: the graph `summary_writer` was created with.
      expected_steps: global steps that must carry a 'global_step/sec'
        summary; no other step may carry a summary.
    """
    summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_graph=graph,
        expected_summaries={})
    self.assertItemsEqual(expected_steps, summary_writer.summaries.keys())
    for step in expected_steps:
      summary_value = summary_writer.summaries[step][0].value[0]
      self.assertEqual('global_step/sec', summary_value.tag)
      self.assertGreater(summary_value.simple_value, 0)

  @test.mock.patch.object(time, 'time')
  def test_step_counter_every_n_steps(self, mock_time):
    mock_time.return_value = MOCK_START_TIME
    with ops.Graph().as_default() as g, session_lib.Session() as sess:
      variables.get_or_create_global_step()
      train_op = training_util._increment_global_step(1)
      summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
      hook = basic_session_run_hooks.StepCounterHook(
          summary_writer=summary_writer, every_n_steps=10)
      hook.begin()
      self.evaluate(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      with test.mock.patch.object(tf_logging, 'warning') as mock_log:
        for _ in range(30):
          mock_time.return_value += 0.01
          mon_sess.run(train_op)
        # logging.warning should not be called.
        self.assertIsNone(mock_log.call_args)
      hook.end(sess)
      self._assert_step_rate_summaries(summary_writer, g, [11, 21])

  @test.mock.patch.object(time, 'time')
  def test_step_counter_every_n_secs(self, mock_time):
    mock_time.return_value = MOCK_START_TIME
    with ops.Graph().as_default() as g, session_lib.Session() as sess:
      variables.get_or_create_global_step()
      train_op = training_util._increment_global_step(1)
      summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
      hook = basic_session_run_hooks.StepCounterHook(
          summary_writer=summary_writer, every_n_steps=None, every_n_secs=0.1)

      hook.begin()
      self.evaluate(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mon_sess.run(train_op)
      mock_time.return_value += 0.2
      mon_sess.run(train_op)
      mock_time.return_value += 0.2
      mon_sess.run(train_op)
      hook.end(sess)

      self.assertTrue(summary_writer.summaries, 'No summaries were created.')
      self._assert_step_rate_summaries(summary_writer, g, [2, 3])

  def test_global_step_name(self):
    with ops.Graph().as_default() as g, session_lib.Session() as sess:
      with variable_scope.variable_scope('bar'):
        variable_scope.get_variable(
            'foo',
            initializer=0,
            trainable=False,
            collections=[
                ops.GraphKeys.GLOBAL_STEP, ops.GraphKeys.GLOBAL_VARIABLES
            ])
      train_op = training_util._increment_global_step(1)
      summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
      hook = basic_session_run_hooks.StepCounterHook(
          summary_writer=summary_writer, every_n_steps=1, every_n_secs=None)

      hook.begin()
      self.evaluate(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mon_sess.run(train_op)
      mon_sess.run(train_op)
      hook.end(sess)

      summary_writer.assert_summaries(
          test_case=self,
          expected_logdir=self.log_dir,
          expected_graph=g,
          expected_summaries={})
      self.assertTrue(summary_writer.summaries, 'No summaries were created.')
      self.assertItemsEqual([2], summary_writer.summaries.keys())
      # The summary tag is derived from the global step variable's name.
      summary_value = summary_writer.summaries[2][0].value[0]
      self.assertEqual('bar/foo/sec', summary_value.tag)

  def test_log_warning_if_global_step_not_increased(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      variables.get_or_create_global_step()
      train_op = training_util._increment_global_step(0)  # keep same.
      self.evaluate(variables_lib.global_variables_initializer())
      hook = basic_session_run_hooks.StepCounterHook(
          every_n_steps=1, every_n_secs=None)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mon_sess.run(train_op)  # Run one step to record global step.
      with test.mock.patch.object(tf_logging, 'warning') as mock_log:
        for _ in range(30):
          mon_sess.run(train_op)
        self.assertRegexpMatches(
            str(mock_log.call_args),
            'global step.*has not been increased')
      hook.end(sess)

  def _setup_steps_per_run_test(self,
                                every_n_steps,
                                steps_per_run,
                                graph,
                                sess):
    """Builds a hook whose train_op advances `steps_per_run` steps per run."""
    variables.get_or_create_global_step()
    self.train_op = training_util._increment_global_step(steps_per_run)
    self.summary_writer = fake_summary_writer.FakeSummaryWriter(
        self.log_dir, graph)
    self.hook = basic_session_run_hooks.StepCounterHook(
        summary_writer=self.summary_writer, every_n_steps=every_n_steps)
    self.hook._set_steps_per_run(steps_per_run)
    self.hook.begin()
    self.evaluate(variables_lib.global_variables_initializer())
    self.mon_sess = monitored_session._HookedSession(sess, [self.hook])

  @test.mock.patch.object(time, 'time')
  def test_steps_per_run_less_than_every_n_steps(self, mock_time):
    mock_time.return_value = MOCK_START_TIME
    with ops.Graph().as_default() as g, session_lib.Session() as sess:
      self._setup_steps_per_run_test(10, 5, g, sess)

      # Logs at 15, 25
      for _ in range(5):
        mock_time.return_value += 0.01
        self.mon_sess.run(self.train_op)

      self.hook.end(sess)
      self._assert_step_rate_summaries(self.summary_writer, g, [15, 25])

  @test.mock.patch.object(time, 'time')
  def test_steps_per_run_equal_every_n_steps(self, mock_time):
    mock_time.return_value = MOCK_START_TIME
    with ops.Graph().as_default() as g, session_lib.Session() as sess:
      self._setup_steps_per_run_test(5, 5, g, sess)

      # Logs at 10, 15, 20, 25
      for _ in range(5):
        mock_time.return_value += 0.01
        self.mon_sess.run(self.train_op)

      self.hook.end(sess)
      self._assert_step_rate_summaries(self.summary_writer, g,
                                       [10, 15, 20, 25])

  @test.mock.patch.object(time, 'time')
  def test_steps_per_run_greater_than_every_n_steps(self, mock_time):
    mock_time.return_value = MOCK_START_TIME
    with ops.Graph().as_default() as g, session_lib.Session() as sess:
      self._setup_steps_per_run_test(5, 10, g, sess)

      # Logs at 20, 30, 40, 50
      for _ in range(5):
        mock_time.return_value += 0.01
        self.mon_sess.run(self.train_op)

      self.hook.end(sess)
      self._assert_step_rate_summaries(self.summary_writer, g,
                                       [20, 30, 40, 50])
1123
1124
@test_util.run_deprecated_v1
class SummarySaverHookTest(test.TestCase):
  """Tests for SummarySaverHook driven by step counts or a mocked clock."""

  def setUp(self):
    test.TestCase.setUp(self)

    self.log_dir = 'log/dir'
    self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)

    counter = variables_lib.Variable(0.0)
    # Computing the summaries runs this assign_add, so the logged value counts
    # the runs in which summaries were actually evaluated.
    incremented = state_ops.assign_add(counter, 1.0)
    doubled = incremented * 2
    self.summary_op = summary_lib.scalar('my_summary', incremented)
    self.summary_op2 = summary_lib.scalar('my_summary2', doubled)

    variables.get_or_create_global_step()
    self.train_op = training_util._increment_global_step(1)

  def _run_hook(self, hook, num_runs, after_each_run=None):
    """Drives `hook` through `num_runs` runs of `train_op`, then ends it.

    Args:
      hook: the SummarySaverHook under test.
      num_runs: how many times `self.train_op` is run.
      after_each_run: optional zero-argument callable invoked after every
        run, e.g. to advance a mocked clock.
    """
    with self.cached_session() as sess:
      hook.begin()
      self.evaluate(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(num_runs):
        mon_sess.run(self.train_op)
        if after_each_run is not None:
          after_each_run()
      hook.end(sess)

  def test_raise_when_scaffold_and_summary_op_both_missing(self):
    with self.assertRaises(ValueError):
      basic_session_run_hooks.SummarySaverHook()

  def test_raise_when_scaffold_and_summary_op_both_present(self):
    with self.assertRaises(ValueError):
      basic_session_run_hooks.SummarySaverHook(
          scaffold=monitored_session.Scaffold(), summary_op=self.summary_op)

  def test_raise_in_both_secs_and_steps(self):
    with self.assertRaises(ValueError):
      basic_session_run_hooks.SummarySaverHook(
          save_secs=10, save_steps=20, summary_writer=self.summary_writer)

  def test_raise_in_none_secs_and_steps(self):
    with self.assertRaises(ValueError):
      basic_session_run_hooks.SummarySaverHook(
          save_secs=None, save_steps=None, summary_writer=self.summary_writer)

  def test_save_steps(self):
    hook = basic_session_run_hooks.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)
    self._run_hook(hook, 30)

    # Written on the first step and every 8 steps after it; the counter only
    # moves when a summary is written, so it reads 1.0 through 4.0.
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            step: {'my_summary': float(nth)}
            for nth, step in enumerate([1, 9, 17, 25], 1)
        })

  def test_multiple_summaries(self):
    hook = basic_session_run_hooks.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=[self.summary_op, self.summary_op2])
    self._run_hook(hook, 10)

    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            step: {'my_summary': float(nth), 'my_summary2': 2.0 * nth}
            for nth, step in enumerate([1, 9], 1)
        })

  @test.mock.patch.object(time, 'time')
  def test_save_secs_saving_once_every_step(self, mock_time):
    mock_time.return_value = MOCK_START_TIME
    hook = basic_session_run_hooks.SummarySaverHook(
        save_secs=0.5,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    def advance_clock():
      mock_time.return_value += 0.5

    self._run_hook(hook, 4, after_each_run=advance_clock)

    # The clock moves by exactly save_secs between runs: every step saves.
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            step: {'my_summary': float(step)} for step in range(1, 5)
        })

  @test.mock.patch.object(time, 'time')
  def test_save_secs_saving_once_every_three_steps(self, mock_time):
    mock_time.return_value = 1484695987.209386
    hook = basic_session_run_hooks.SummarySaverHook(
        save_secs=9.,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    def advance_clock():
      mock_time.return_value += 3.1

    self._run_hook(hook, 8, after_each_run=advance_clock)

    # 8 runs * 3.1s = 24.8s elapse; with save_secs=9 a summary is written on
    # the first step and then on every third step.
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            step: {'my_summary': float(nth)}
            for nth, step in enumerate([1, 4, 7], 1)
        })
1289
1290
class GlobalStepWaiterHookTest(test.TestCase):
  """Tests that GlobalStepWaiterHook.before_run waits for a target step."""

  def test_not_wait_for_step_zero(self):
    with ops.Graph().as_default():
      variables.get_or_create_global_step()
      waiter = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=0)
      waiter.begin()
      with session_lib.Session() as sess:
        # With a target of step 0 there is nothing to wait for, so this call
        # must return even though the global step never advances.
        context = session_run_hook.SessionRunContext(
            original_args=None, session=sess)
        waiter.before_run(context)

  @test.mock.patch.object(time, 'sleep')
  def test_wait_for_step(self, mock_sleep):
    with ops.Graph().as_default():
      gstep = variables.get_or_create_global_step()
      waiter = basic_session_run_hooks.GlobalStepWaiterHook(
          wait_until_step=1000)
      waiter.begin()

      with session_lib.Session() as sess:
        # Each mocked time.sleep() call advances the global step to the next
        # pending value; a third call would mean before_run() kept waiting
        # after the step had already passed 1000.
        pending_steps = [500, 1100]
        sleep_calls = []

        def fake_sleep(seconds):
          del seconds  # The duration is irrelevant; nothing really sleeps.
          sleep_calls.append(True)
          if not pending_steps:
            raise AssertionError(
                'Expected before_run() to terminate after the second call to '
                'time.sleep()')
          sess.run(state_ops.assign(gstep, pending_steps.pop(0)))

        mock_sleep.side_effect = fake_sleep

        self.evaluate(variables_lib.global_variables_initializer())
        waiter.before_run(
            session_run_hook.SessionRunContext(
                original_args=None, session=sess))
        self.assertEqual(len(sleep_calls), 2)
1341
1342
class FinalOpsHookTest(test.TestCase):
  """Tests for FinalOpsHook, which evaluates its ops in `end()`."""

  def test_final_ops_is_scalar_tensor(self):
    with ops.Graph().as_default():
      scalar = 4
      hook = basic_session_run_hooks.FinalOpsHook(constant_op.constant(scalar))
      hook.begin()

      with session_lib.Session() as session:
        hook.end(session)
        self.assertEqual(scalar, hook.final_ops_values)

  def test_final_ops_is_tensor(self):
    with ops.Graph().as_default():
      values = [1, 6, 3, 5, 2, 4]
      hook = basic_session_run_hooks.FinalOpsHook(constant_op.constant(values))
      hook.begin()

      with session_lib.Session() as session:
        hook.end(session)
        # The result is array-like (it has .tolist()); compare as a list.
        self.assertListEqual(values, hook.final_ops_values.tolist())

  def test_final_ops_triggers_out_of_range_error(self):
    with ops.Graph().as_default():
      # A one-element dataset: the element is consumed below, so evaluating
      # the same get_next() op again inside hook.end() runs past the end.
      dataset = dataset_ops.Dataset.range(1)
      next_element = dataset_ops.make_one_shot_iterator(dataset).get_next()

      hook = basic_session_run_hooks.FinalOpsHook(next_element)
      hook.begin()

      with session_lib.Session() as session:
        session.run(next_element)
        with test.mock.patch.object(tf_logging, 'warning') as mock_log:
          with self.assertRaisesRegexp(errors.OutOfRangeError,
                                       'End of sequence'):
            hook.end(session)
          self.assertRegexpMatches(
              str(mock_log.call_args),
              'dependency back to some input source')

  def test_final_ops_with_dictionary(self):
    with ops.Graph().as_default():
      fed_values = [4, -3]
      placeholder = array_ops.placeholder(dtype=dtypes.float32)

      # The second argument is the feed dict used when evaluating final ops.
      hook = basic_session_run_hooks.FinalOpsHook(
          placeholder, {placeholder: fed_values})
      hook.begin()

      with session_lib.Session() as session:
        hook.end(session)
        self.assertListEqual(fed_values, hook.final_ops_values.tolist())
1405
1406
@test_util.run_deprecated_v1
class ResourceSummarySaverHookTest(test.TestCase):
  """SummarySaverHook tests using resource variables for state and step."""

  def setUp(self):
    test.TestCase.setUp(self)

    self.log_dir = 'log/dir'
    self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)

    resource_var = variable_scope.get_variable(
        'var', initializer=0.0, use_resource=True)
    # Computing the summary runs the assign_add, so the logged value counts
    # how many times the summary was evaluated.
    self.summary_op = summary_lib.scalar(
        'my_summary', state_ops.assign_add(resource_var, 1.0))

    with variable_scope.variable_scope('foo', use_resource=True):
      variables.create_global_step()
    self.train_op = training_util._increment_global_step(1)

  def test_save_steps(self):
    hook = basic_session_run_hooks.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.cached_session() as sess:
      hook.begin()
      self.evaluate(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(30):
        mon_sess.run(self.train_op)
      hook.end(sess)

    # Written on the first step and every 8 steps after it, reading 1.0..4.0.
    saved_steps = [1, 9, 17, 25]
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            step: {'my_summary': float(nth)}
            for nth, step in enumerate(saved_steps, 1)
        })
1455
1456
class FeedFnHookTest(test.TestCase):
  """Tests for FeedFnHook."""

  def test_feeding_placeholder(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      placeholder = array_ops.placeholder(dtype=dtypes.float32)
      plus_one = placeholder + 1

      def feed_fn():
        # Supplies the value for the placeholder, which the run below needs.
        return {placeholder: 1.0}

      hook = basic_session_run_hooks.FeedFnHook(feed_fn=feed_fn)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      self.assertEqual(mon_sess.run(plus_one), 2)
1468
1469
class ProfilerHookTest(test.TestCase):
  """Tests for ProfilerHook's timeline files and run-metadata output."""

  def setUp(self):
    super(ProfilerHookTest, self).setUp()
    self.output_dir = tempfile.mkdtemp()
    self.graph = ops.Graph()
    self.filepattern = os.path.join(self.output_dir, 'timeline-*.json')
    with self.graph.as_default():
      self.global_step = variables.get_or_create_global_step()
      self.train_op = state_ops.assign_add(self.global_step, 1)

  def tearDown(self):
    super(ProfilerHookTest, self).tearDown()
    shutil.rmtree(self.output_dir, ignore_errors=True)

  def _count_timeline_files(self):
    """Returns how many 'timeline-*.json' files exist in output_dir."""
    return len(gfile.Glob(self.filepattern))

  @test_util.run_deprecated_v1
  def test_raise_in_both_secs_and_steps(self):
    with self.assertRaises(ValueError):
      basic_session_run_hooks.ProfilerHook(save_secs=10, save_steps=20)

  @test_util.run_deprecated_v1
  def test_raise_in_none_secs_and_steps(self):
    with self.assertRaises(ValueError):
      basic_session_run_hooks.ProfilerHook(save_secs=None, save_steps=None)

  def test_save_secs_does_not_save_in_first_step(self):
    with self.graph.as_default():
      hook = basic_session_run_hooks.ProfilerHook(
          save_secs=2, output_dir=self.output_dir)
      with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
        sess.run(self.train_op)
        self.assertEqual(0, self._count_timeline_files())

  @test.mock.patch.object(time, 'time')
  def test_save_secs_saves_periodically(self, mock_time):
    # Pick a fixed start time.
    with self.graph.as_default():
      mock_time.return_value = MOCK_START_TIME
      hook = basic_session_run_hooks.ProfilerHook(
          save_secs=2, output_dir=self.output_dir)
      with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
        sess.run(self.train_op)  # Not saved.
        self.assertEqual(0, self._count_timeline_files())
        # Simulate 2.5 seconds of sleep.
        mock_time.return_value = MOCK_START_TIME + 2.5
        sess.run(self.train_op)  # Saved.
        self.assertEqual(1, self._count_timeline_files())

        # Pretend some small amount of time has passed.
        mock_time.return_value = MOCK_START_TIME + 2.6
        sess.run(self.train_op)  # Not saved.
        # Edge test just before we should save the timeline.
        mock_time.return_value = MOCK_START_TIME + 4.4
        sess.run(self.train_op)  # Not saved.
        self.assertEqual(1, self._count_timeline_files())

        mock_time.return_value = MOCK_START_TIME + 4.5
        sess.run(self.train_op)  # Saved.
        self.assertEqual(2, self._count_timeline_files())

  def test_save_steps_does_not_save_in_first_step(self):
    with self.graph.as_default():
      hook = basic_session_run_hooks.ProfilerHook(
          save_steps=1, output_dir=self.output_dir)
      with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
        sess.run(self.train_op)  # Not saved.
        self.assertEqual(0, self._count_timeline_files())

  def test_save_steps_saves_periodically(self):
    with self.graph.as_default():
      hook = basic_session_run_hooks.ProfilerHook(
          save_steps=2, output_dir=self.output_dir)
      with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
        self.assertEqual(0, self._count_timeline_files())
        # Timeline files on disk after each run: one is added every 2nd run.
        for expected_files in [0, 1, 1, 2, 2]:
          sess.run(self.train_op)
          self.assertEqual(expected_files, self._count_timeline_files())

  def test_run_metadata_saves(self):
    writer_cache.FileWriterCache.clear()
    fake_summary_writer.FakeSummaryWriter.install()
    # install() changes global state, so always undo it, even when a step or
    # assertion below fails; otherwise later tests inherit the fake writer.
    try:
      fake_writer = writer_cache.FileWriterCache.get(self.output_dir)
      with self.graph.as_default():
        hook = basic_session_run_hooks.ProfilerHook(
            save_steps=1, output_dir=self.output_dir)
        with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
          sess.run(self.train_op)  # Not saved.
          sess.run(self.train_op)  # Saved.
          self.assertEqual(
              list(fake_writer._added_run_metadata.keys()), ['step_2'])
    finally:
      fake_summary_writer.FakeSummaryWriter.uninstall()
1571
1572
if __name__ == '__main__':
  # Run this module's test cases through the TensorFlow test runner.
  test.main()
1575