• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# pylint: disable=g-bad-file-header
2# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Tests for basic_session_run_hooks."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import os.path
23import shutil
24import tempfile
25import time
26
27from tensorflow.python.client import session as session_lib
28from tensorflow.python.data.ops import dataset_ops
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import errors
32from tensorflow.python.framework import meta_graph
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import test_util
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import state_ops
38from tensorflow.python.ops import variable_scope
39from tensorflow.python.ops import variables as variables_lib
40import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
41from tensorflow.python.platform import gfile
42from tensorflow.python.platform import test
43from tensorflow.python.platform import tf_logging
44from tensorflow.python.summary import summary as summary_lib
45from tensorflow.python.summary.writer import fake_summary_writer
46from tensorflow.python.summary.writer import writer_cache
47from tensorflow.python.training import basic_session_run_hooks
48from tensorflow.python.training import checkpoint_utils
49from tensorflow.python.training import monitored_session
50from tensorflow.python.training import session_run_hook
51from tensorflow.python.training import training_util
52
53
# Fixed, realistic wall-clock value (same scale as time.time()) that tests
# assign to the mocked time.time() so that time-based triggers are
# deterministic.
MOCK_START_TIME = 1484695987.209386
57
58
class MockCheckpointSaverListener(
    basic_session_run_hooks.CheckpointSaverListener):
  """Checkpoint listener that only counts how often each callback runs.

  Setting `ask_for_stop` to a truthy value makes `after_save` return True.
  """

  def __init__(self):
    # One counter per callback, all starting from zero.
    self.begin_count = self.before_save_count = 0
    self.after_save_count = self.end_count = 0
    # Consulted by after_save() to decide what to return.
    self.ask_for_stop = False

  def begin(self):
    """Counts one `begin` callback."""
    self.begin_count = self.begin_count + 1

  def before_save(self, session, global_step):
    """Counts one `before_save` callback; the arguments are ignored."""
    self.before_save_count = self.before_save_count + 1

  def after_save(self, session, global_step):
    """Counts one `after_save` callback.

    Returns:
      True when `ask_for_stop` is truthy, otherwise None.
    """
    self.after_save_count = self.after_save_count + 1
    return True if self.ask_for_stop else None

  def end(self, session, global_step):
    """Counts one `end` callback; the arguments are ignored."""
    self.end_count = self.end_count + 1

  def get_counts(self):
    """Returns the four counters as a dict keyed by callback name."""
    return dict(
        begin=self.begin_count,
        before_save=self.before_save_count,
        after_save=self.after_save_count,
        end=self.end_count)
90
91
class SecondOrStepTimerTest(test.TestCase):
  """Tests for `basic_session_run_hooks.SecondOrStepTimer`."""

  @test_util.run_deprecated_v1
  def test_raise_in_both_secs_and_steps(self):
    """Giving both `every_secs` and `every_steps` raises ValueError."""
    timer_cls = basic_session_run_hooks.SecondOrStepTimer
    with self.assertRaises(ValueError):
      timer_cls(every_secs=2.0, every_steps=10)

  @test_util.run_deprecated_v1
  def test_raise_in_none_secs_and_steps(self):
    """Giving neither `every_secs` nor `every_steps` raises ValueError."""
    timer_cls = basic_session_run_hooks.SecondOrStepTimer
    with self.assertRaises(ValueError):
      timer_cls()

  @test.mock.patch.object(time, 'time')
  def test_every_secs(self, mock_time):
    """A secs-based timer fires for a new step once the mocked clock moves."""
    mock_time.return_value = MOCK_START_TIME
    secs_timer = basic_session_run_hooks.SecondOrStepTimer(every_secs=1.0)
    # A fresh timer fires immediately.
    self.assertTrue(secs_timer.should_trigger_for_step(1))

    secs_timer.update_last_triggered_step(1)
    for step in (1, 2):
      self.assertFalse(secs_timer.should_trigger_for_step(step))

    # One mocked second later only a different step fires.
    mock_time.return_value = MOCK_START_TIME + 1.0
    self.assertFalse(secs_timer.should_trigger_for_step(1))
    self.assertTrue(secs_timer.should_trigger_for_step(2))

  def test_every_steps(self):
    """A steps-based timer fires again `every_steps` after the last trigger."""
    steps_timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=3)
    # A fresh timer fires immediately.
    self.assertTrue(steps_timer.should_trigger_for_step(1))

    steps_timer.update_last_triggered_step(1)
    for step in (1, 2, 3):
      self.assertFalse(steps_timer.should_trigger_for_step(step))
    self.assertTrue(steps_timer.should_trigger_for_step(4))

  def test_update_last_triggered_step(self):
    """`update_last_triggered_step` reports elapsed secs/steps since last call."""
    step_timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=1)

    # The very first update has nothing to measure against.
    elapsed_secs, elapsed_steps = step_timer.update_last_triggered_step(1)
    self.assertEqual(None, elapsed_secs)
    self.assertEqual(None, elapsed_steps)

    for step, expected_steps in ((5, 4), (7, 2)):
      elapsed_secs, elapsed_steps = step_timer.update_last_triggered_step(step)
      self.assertGreater(elapsed_secs, 0)
      self.assertEqual(expected_steps, elapsed_steps)
142
143
class StopAtStepTest(test.TestCase):
  """Tests for `basic_session_run_hooks.StopAtStepHook`."""

  def _stop_requested_at(self, raw_sess, hooked_sess, step_var, run_op, step,
                         clear_stop_flag=False):
    """Moves the global step to `step`, runs once, returns should_stop().

    Args:
      raw_sess: Plain session used to assign the global step directly.
      hooked_sess: `_HookedSession` around `raw_sess` with the hook under test.
      step_var: The global step variable.
      run_op: Op executed through `hooked_sess`.
      step: Value written to `step_var` before the run.
      clear_stop_flag: If True, `hooked_sess._should_stop` is reset to False
        before the run, so a True result must come from that run.

    Returns:
      `hooked_sess.should_stop()` after the run.
    """
    raw_sess.run(state_ops.assign(step_var, step))
    if clear_stop_flag:
      hooked_sess._should_stop = False
    hooked_sess.run(run_op)
    return hooked_sess.should_stop()

  def test_raise_in_both_last_step_and_num_steps(self):
    """Giving both `num_steps` and `last_step` raises ValueError."""
    hook_cls = basic_session_run_hooks.StopAtStepHook
    with self.assertRaises(ValueError):
      hook_cls(num_steps=10, last_step=20)

  def test_stop_based_on_last_step(self):
    """With last_step=10 a stop is requested from global step 10 onwards."""
    hook = basic_session_run_hooks.StopAtStepHook(last_step=10)
    with ops.Graph().as_default():
      step_var = training_util.get_or_create_global_step()
      run_op = control_flow_ops.no_op()
      hook.begin()
      with session_lib.Session() as raw_sess:
        hooked_sess = monitored_session._HookedSession(raw_sess, [hook])
        raw_sess.run(state_ops.assign(step_var, 5))
        hook.after_create_session(raw_sess, None)
        hooked_sess.run(run_op)
        self.assertFalse(hooked_sess.should_stop())

        ctx = (raw_sess, hooked_sess, step_var, run_op)
        self.assertFalse(self._stop_requested_at(*ctx, step=9))
        self.assertTrue(self._stop_requested_at(*ctx, step=10))
        # Beyond the last step the stop must be requested again.
        self.assertTrue(
            self._stop_requested_at(*ctx, step=11, clear_stop_flag=True))

  def test_stop_based_on_num_step(self):
    """With num_steps=10 and a start at step 5, the stop comes at step 15."""
    hook = basic_session_run_hooks.StopAtStepHook(num_steps=10)

    with ops.Graph().as_default():
      step_var = training_util.get_or_create_global_step()
      run_op = control_flow_ops.no_op()
      hook.begin()
      with session_lib.Session() as raw_sess:
        hooked_sess = monitored_session._HookedSession(raw_sess, [hook])
        raw_sess.run(state_ops.assign(step_var, 5))
        hook.after_create_session(raw_sess, None)
        hooked_sess.run(run_op)
        self.assertFalse(hooked_sess.should_stop())

        ctx = (raw_sess, hooked_sess, step_var, run_op)
        for step in (13, 14):
          self.assertFalse(self._stop_requested_at(*ctx, step=step))
        self.assertTrue(self._stop_requested_at(*ctx, step=15))
        # Beyond the last step the stop must be requested again.
        self.assertTrue(
            self._stop_requested_at(*ctx, step=16, clear_stop_flag=True))

  def test_stop_based_with_multiple_steps(self):
    """Jumping straight from step 5 to step 15 also requests the stop."""
    hook = basic_session_run_hooks.StopAtStepHook(num_steps=10)

    with ops.Graph().as_default():
      step_var = training_util.get_or_create_global_step()
      run_op = control_flow_ops.no_op()
      hook.begin()
      with session_lib.Session() as raw_sess:
        hooked_sess = monitored_session._HookedSession(raw_sess, [hook])
        raw_sess.run(state_ops.assign(step_var, 5))
        hook.after_create_session(raw_sess, None)
        hooked_sess.run(run_op)
        self.assertFalse(hooked_sess.should_stop())
        self.assertTrue(
            self._stop_requested_at(
                raw_sess, hooked_sess, step_var, run_op, step=15))
216
217
class LoggingTensorHookTest(test.TestCase):
  """Tests for `basic_session_run_hooks.LoggingTensorHook`.

  `tf_logging.info` is replaced for the duration of each test so that the
  positional arguments of its most recent call are kept in
  `self.logged_message`.
  """

  def setUp(self):
    # Mock out logging calls so we can verify whether correct tensors are being
    # monitored.
    self._actual_log = tf_logging.info
    self.logged_message = None

    def mock_log(*args, **kwargs):
      # Remember the latest call, then forward it to the real logger.
      self.logged_message = args
      self._actual_log(*args, **kwargs)

    tf_logging.info = mock_log

  def tearDown(self):
    # Put back the real logger that setUp replaced.
    tf_logging.info = self._actual_log

  def _assert_logged(self, pattern):
    """Asserts that `str(self.logged_message)` matches the regex `pattern`."""
    self.assertRegex(str(self.logged_message), pattern)

  def _assert_not_logged(self, text):
    """Asserts that `text` is not a substring of `str(self.logged_message)`."""
    # assertNotIn performs the same substring check as `find(...) == -1` but
    # prints both strings when it fails.
    self.assertNotIn(text, str(self.logged_message))

  def test_illegal_args(self):
    """Invalid `every_n_iter` / trigger combinations raise ValueError."""
    with self.assertRaisesRegex(ValueError, 'nvalid every_n_iter'):
      basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=0)
    with self.assertRaisesRegex(ValueError, 'nvalid every_n_iter'):
      basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=-10)
    with self.assertRaisesRegex(ValueError, 'xactly one of'):
      basic_session_run_hooks.LoggingTensorHook(
          tensors=['t'], every_n_iter=5, every_n_secs=5)
    with self.assertRaisesRegex(ValueError, 'xactly one of'):
      basic_session_run_hooks.LoggingTensorHook(tensors=['t'])

  def test_print_at_end_only(self):
    """With only `at_end=True` the tensor is logged by `end`, not by runs."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
      t = constant_op.constant(42.0, name='foo')
      train_op = constant_op.constant(3)
      hook = basic_session_run_hooks.LoggingTensorHook(
          tensors=[t.name], at_end=True)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      self.evaluate(variables_lib.global_variables_initializer())
      self.logged_message = ''
      for _ in range(3):
        mon_sess.run(train_op)
        self._assert_not_logged(t.name)

      hook.end(sess)
      self._assert_logged(t.name)

  def _validate_print_every_n_steps(self, sess, at_end):
    """Checks logging on the first run and then on every 10th run.

    Args:
      sess: Session in whose graph the tensors and hook are created.
      at_end: Forwarded to the hook; decides whether `end` is expected to log.
    """
    t = constant_op.constant(42.0, name='foo')

    train_op = constant_op.constant(3)
    hook = basic_session_run_hooks.LoggingTensorHook(
        tensors=[t.name], every_n_iter=10, at_end=at_end)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])
    self.evaluate(variables_lib.global_variables_initializer())
    mon_sess.run(train_op)
    self._assert_logged(t.name)
    for _ in range(3):
      self.logged_message = ''
      # Nine silent runs followed by one run that logs.
      for _ in range(9):
        mon_sess.run(train_op)
        self._assert_not_logged(t.name)
      mon_sess.run(train_op)
      self._assert_logged(t.name)

    # Add additional run to verify proper reset when called multiple times.
    self.logged_message = ''
    mon_sess.run(train_op)
    self._assert_not_logged(t.name)

    self.logged_message = ''
    hook.end(sess)
    if at_end:
      self._assert_logged(t.name)
    else:
      self._assert_not_logged(t.name)

  def test_print_every_n_steps(self):
    """Step-based logging without `at_end`, validated twice in one session."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
      self._validate_print_every_n_steps(sess, at_end=False)
      # Verify proper reset.
      self._validate_print_every_n_steps(sess, at_end=False)

  def test_print_every_n_steps_and_end(self):
    """Step-based logging with `at_end`, validated twice in one session."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
      self._validate_print_every_n_steps(sess, at_end=True)
      # Verify proper reset.
      self._validate_print_every_n_steps(sess, at_end=True)

  def test_print_first_step(self):
    """The first logged message names the tensor but contains no 'sec'."""
    # if it runs every iteration, first iteration has None duration.
    with ops.Graph().as_default(), session_lib.Session() as sess:
      t = constant_op.constant(42.0, name='foo')
      train_op = constant_op.constant(3)
      hook = basic_session_run_hooks.LoggingTensorHook(
          tensors={'foo': t}, every_n_iter=1)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      self.evaluate(variables_lib.global_variables_initializer())
      mon_sess.run(train_op)
      self._assert_logged('foo')
      # in first run, elapsed time is None.
      self._assert_not_logged('sec')

  def _validate_print_every_n_secs(self, sess, at_end, mock_time):
    """Checks logging on the first run and again after one mocked second.

    Args:
      sess: Session in whose graph the tensors and hook are created.
      at_end: Forwarded to the hook; decides whether `end` is expected to log.
      mock_time: Mock standing in for `time.time`; its return value is
        advanced by one second during the check.
    """
    t = constant_op.constant(42.0, name='foo')
    train_op = constant_op.constant(3)

    hook = basic_session_run_hooks.LoggingTensorHook(
        tensors=[t.name], every_n_secs=1.0, at_end=at_end)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])
    self.evaluate(variables_lib.global_variables_initializer())

    mon_sess.run(train_op)
    self._assert_logged(t.name)

    # No mocked time has passed yet, so this run must stay silent.
    self.logged_message = ''
    mon_sess.run(train_op)
    self._assert_not_logged(t.name)
    mock_time.return_value += 1.0

    self.logged_message = ''
    mon_sess.run(train_op)
    self._assert_logged(t.name)

    self.logged_message = ''
    hook.end(sess)
    if at_end:
      self._assert_logged(t.name)
    else:
      self._assert_not_logged(t.name)

  @test.mock.patch.object(time, 'time')
  def test_print_every_n_secs(self, mock_time):
    """Time-based logging without `at_end`, validated twice in one session."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
      mock_time.return_value = MOCK_START_TIME
      self._validate_print_every_n_secs(sess, at_end=False, mock_time=mock_time)
      # Verify proper reset.
      self._validate_print_every_n_secs(sess, at_end=False, mock_time=mock_time)

  @test.mock.patch.object(time, 'time')
  def test_print_every_n_secs_and_end(self, mock_time):
    """Time-based logging with `at_end`, validated twice in one session."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
      mock_time.return_value = MOCK_START_TIME
      self._validate_print_every_n_secs(sess, at_end=True, mock_time=mock_time)
      # Verify proper reset.
      self._validate_print_every_n_secs(sess, at_end=True, mock_time=mock_time)

  def test_print_formatter(self):
    """A custom `formatter` produces the logged message."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
      t = constant_op.constant(42.0, name='foo')
      train_op = constant_op.constant(3)
      hook = basic_session_run_hooks.LoggingTensorHook(
          tensors=[t.name], every_n_iter=10,
          formatter=lambda items: 'qqq=%s' % items[t.name])
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      self.evaluate(variables_lib.global_variables_initializer())
      mon_sess.run(train_op)
      self.assertEqual(self.logged_message[0], 'qqq=42.0')
384
385
386class CheckpointSaverHookTest(test.TestCase):
387
388  def setUp(self):
389    self.model_dir = tempfile.mkdtemp()
390    self.graph = ops.Graph()
391    with self.graph.as_default():
392      self.scaffold = monitored_session.Scaffold()
393      self.global_step = training_util.get_or_create_global_step()
394      self.train_op = training_util._increment_global_step(1)
395
  def tearDown(self):
    """Deletes the scratch checkpoint directory, ignoring removal errors."""
    shutil.rmtree(self.model_dir, ignore_errors=True)
398
399  def test_saves_when_saver_and_scaffold_both_missing(self):
400    with self.graph.as_default():
401      hook = basic_session_run_hooks.CheckpointSaverHook(
402          self.model_dir, save_steps=1)
403      hook.begin()
404      self.scaffold.finalize()
405      with session_lib.Session() as sess:
406        sess.run(self.scaffold.init_op)
407        mon_sess = monitored_session._HookedSession(sess, [hook])
408        mon_sess.run(self.train_op)
409        self.assertEqual(1,
410                         checkpoint_utils.load_variable(self.model_dir,
411                                                        self.global_step.name))
412
413  def test_raise_when_saver_and_scaffold_both_present(self):
414    with self.assertRaises(ValueError):
415      basic_session_run_hooks.CheckpointSaverHook(
416          self.model_dir, saver=self.scaffold.saver, scaffold=self.scaffold)
417
418  @test_util.run_deprecated_v1
419  def test_raise_in_both_secs_and_steps(self):
420    with self.assertRaises(ValueError):
421      basic_session_run_hooks.CheckpointSaverHook(
422          self.model_dir, save_secs=10, save_steps=20)
423
424  @test_util.run_deprecated_v1
425  def test_raise_in_none_secs_and_steps(self):
426    with self.assertRaises(ValueError):
427      basic_session_run_hooks.CheckpointSaverHook(self.model_dir)
428
429  def test_save_secs_saves_in_first_step(self):
430    with self.graph.as_default():
431      hook = basic_session_run_hooks.CheckpointSaverHook(
432          self.model_dir, save_secs=2, scaffold=self.scaffold)
433      hook.begin()
434      self.scaffold.finalize()
435      with session_lib.Session() as sess:
436        sess.run(self.scaffold.init_op)
437        mon_sess = monitored_session._HookedSession(sess, [hook])
438        mon_sess.run(self.train_op)
439        self.assertEqual(1,
440                         checkpoint_utils.load_variable(self.model_dir,
441                                                        self.global_step.name))
442
443  def test_save_secs_calls_listeners_at_begin_and_end(self):
444    with self.graph.as_default():
445      listener = MockCheckpointSaverListener()
446      hook = basic_session_run_hooks.CheckpointSaverHook(
447          self.model_dir,
448          save_secs=2,
449          scaffold=self.scaffold,
450          listeners=[listener])
451      hook.begin()
452      self.scaffold.finalize()
453      with session_lib.Session() as sess:
454        sess.run(self.scaffold.init_op)
455        mon_sess = monitored_session._HookedSession(sess, [hook])
456        mon_sess.run(self.train_op)  # hook runs here
457        mon_sess.run(self.train_op)  # hook won't run here, so it does at end
458        hook.end(sess)  # hook runs here
459      self.assertEqual({
460          'begin': 1,
461          'before_save': 2,
462          'after_save': 2,
463          'end': 1
464      }, listener.get_counts())
465
466  def test_listener_with_monitored_session(self):
467    with ops.Graph().as_default():
468      scaffold = monitored_session.Scaffold()
469      global_step = training_util.get_or_create_global_step()
470      train_op = training_util._increment_global_step(1)
471      listener = MockCheckpointSaverListener()
472      hook = basic_session_run_hooks.CheckpointSaverHook(
473          self.model_dir,
474          save_steps=1,
475          scaffold=scaffold,
476          listeners=[listener])
477      with monitored_session.SingularMonitoredSession(
478          hooks=[hook],
479          scaffold=scaffold,
480          checkpoint_dir=self.model_dir) as sess:
481        sess.run(train_op)
482        sess.run(train_op)
483        global_step_val = sess.raw_session().run(global_step)
484      listener_counts = listener.get_counts()
485    self.assertEqual(2, global_step_val)
486    self.assertEqual({
487        'begin': 1,
488        'before_save': 3,
489        'after_save': 3,
490        'end': 1
491    }, listener_counts)
492
493  def test_listener_stops_training_in_after_save(self):
494    with ops.Graph().as_default():
495      scaffold = monitored_session.Scaffold()
496      training_util.get_or_create_global_step()
497      train_op = training_util._increment_global_step(1)
498      listener = MockCheckpointSaverListener()
499      hook = basic_session_run_hooks.CheckpointSaverHook(
500          self.model_dir, save_steps=1, scaffold=scaffold, listeners=[listener])
501      with monitored_session.SingularMonitoredSession(
502          hooks=[hook], scaffold=scaffold,
503          checkpoint_dir=self.model_dir) as sess:
504        sess.run(train_op)
505        self.assertFalse(sess.should_stop())
506        sess.run(train_op)
507        self.assertFalse(sess.should_stop())
508        listener.ask_for_stop = True
509        sess.run(train_op)
510        self.assertTrue(sess.should_stop())
511
512  def test_listener_with_default_saver(self):
513    with ops.Graph().as_default():
514      global_step = training_util.get_or_create_global_step()
515      train_op = training_util._increment_global_step(1)
516      listener = MockCheckpointSaverListener()
517      hook = basic_session_run_hooks.CheckpointSaverHook(
518          self.model_dir,
519          save_steps=1,
520          listeners=[listener])
521      with monitored_session.SingularMonitoredSession(
522          hooks=[hook],
523          checkpoint_dir=self.model_dir) as sess:
524        sess.run(train_op)
525        sess.run(train_op)
526        global_step_val = sess.raw_session().run(global_step)
527      listener_counts = listener.get_counts()
528    self.assertEqual(2, global_step_val)
529    self.assertEqual({
530        'begin': 1,
531        'before_save': 3,
532        'after_save': 3,
533        'end': 1
534    }, listener_counts)
535
536    with ops.Graph().as_default():
537      global_step = training_util.get_or_create_global_step()
538      with monitored_session.SingularMonitoredSession(
539          checkpoint_dir=self.model_dir) as sess2:
540        global_step_saved_val = sess2.run(global_step)
541    self.assertEqual(2, global_step_saved_val)
542
543  def test_two_listeners_with_default_saver(self):
544    with ops.Graph().as_default():
545      global_step = training_util.get_or_create_global_step()
546      train_op = training_util._increment_global_step(1)
547      listener1 = MockCheckpointSaverListener()
548      listener2 = MockCheckpointSaverListener()
549      hook = basic_session_run_hooks.CheckpointSaverHook(
550          self.model_dir,
551          save_steps=1,
552          listeners=[listener1, listener2])
553      with monitored_session.SingularMonitoredSession(
554          hooks=[hook],
555          checkpoint_dir=self.model_dir) as sess:
556        sess.run(train_op)
557        sess.run(train_op)
558        global_step_val = sess.raw_session().run(global_step)
559      listener1_counts = listener1.get_counts()
560      listener2_counts = listener2.get_counts()
561    self.assertEqual(2, global_step_val)
562    self.assertEqual({
563        'begin': 1,
564        'before_save': 3,
565        'after_save': 3,
566        'end': 1
567    }, listener1_counts)
568    self.assertEqual(listener1_counts, listener2_counts)
569
570    with ops.Graph().as_default():
571      global_step = training_util.get_or_create_global_step()
572      with monitored_session.SingularMonitoredSession(
573          checkpoint_dir=self.model_dir) as sess2:
574        global_step_saved_val = sess2.run(global_step)
575    self.assertEqual(2, global_step_saved_val)
576
577  @test.mock.patch.object(time, 'time')
578  def test_save_secs_saves_periodically(self, mock_time):
579    with self.graph.as_default():
580      mock_time.return_value = MOCK_START_TIME
581      hook = basic_session_run_hooks.CheckpointSaverHook(
582          self.model_dir, save_secs=2, scaffold=self.scaffold)
583      hook.begin()
584      self.scaffold.finalize()
585
586      with session_lib.Session() as sess:
587        sess.run(self.scaffold.init_op)
588        mon_sess = monitored_session._HookedSession(sess, [hook])
589
590        mock_time.return_value = MOCK_START_TIME
591        mon_sess.run(self.train_op)  # Saved.
592
593        mock_time.return_value = MOCK_START_TIME + 0.5
594        mon_sess.run(self.train_op)  # Not saved.
595
596        self.assertEqual(1,
597                         checkpoint_utils.load_variable(self.model_dir,
598                                                        self.global_step.name))
599
600        # Simulate 2.5 seconds of sleep.
601        mock_time.return_value = MOCK_START_TIME + 2.5
602        mon_sess.run(self.train_op)  # Saved.
603
604        mock_time.return_value = MOCK_START_TIME + 2.6
605        mon_sess.run(self.train_op)  # Not saved.
606
607        mock_time.return_value = MOCK_START_TIME + 2.7
608        mon_sess.run(self.train_op)  # Not saved.
609
610        self.assertEqual(3,
611                         checkpoint_utils.load_variable(self.model_dir,
612                                                        self.global_step.name))
613
614        # Simulate 7.5 more seconds of sleep (10 seconds from start.
615        mock_time.return_value = MOCK_START_TIME + 10
616        mon_sess.run(self.train_op)  # Saved.
617        self.assertEqual(6,
618                         checkpoint_utils.load_variable(self.model_dir,
619                                                        self.global_step.name))
620
621  @test.mock.patch.object(time, 'time')
622  def test_save_secs_calls_listeners_periodically(self, mock_time):
623    with self.graph.as_default():
624      mock_time.return_value = MOCK_START_TIME
625      listener = MockCheckpointSaverListener()
626      hook = basic_session_run_hooks.CheckpointSaverHook(
627          self.model_dir,
628          save_secs=2,
629          scaffold=self.scaffold,
630          listeners=[listener])
631      hook.begin()
632      self.scaffold.finalize()
633      with session_lib.Session() as sess:
634        sess.run(self.scaffold.init_op)
635        mon_sess = monitored_session._HookedSession(sess, [hook])
636
637        mock_time.return_value = MOCK_START_TIME + 0.5
638        mon_sess.run(self.train_op)  # hook runs here
639
640        mock_time.return_value = MOCK_START_TIME + 0.5
641        mon_sess.run(self.train_op)
642
643        mock_time.return_value = MOCK_START_TIME + 3.0
644        mon_sess.run(self.train_op)  # hook runs here
645
646        mock_time.return_value = MOCK_START_TIME + 3.5
647        mon_sess.run(self.train_op)
648
649        mock_time.return_value = MOCK_START_TIME + 4.0
650        mon_sess.run(self.train_op)
651
652        mock_time.return_value = MOCK_START_TIME + 6.5
653        mon_sess.run(self.train_op)  # hook runs here
654
655        mock_time.return_value = MOCK_START_TIME + 7.0
656        mon_sess.run(self.train_op)  # hook won't run here, so it does at end
657
658        mock_time.return_value = MOCK_START_TIME + 7.5
659        hook.end(sess)  # hook runs here
660      self.assertEqual({
661          'begin': 1,
662          'before_save': 4,
663          'after_save': 4,
664          'end': 1
665      }, listener.get_counts())
666
667  def test_save_steps_saves_in_first_step(self):
668    with self.graph.as_default():
669      hook = basic_session_run_hooks.CheckpointSaverHook(
670          self.model_dir, save_steps=2, scaffold=self.scaffold)
671      hook.begin()
672      self.scaffold.finalize()
673      with session_lib.Session() as sess:
674        sess.run(self.scaffold.init_op)
675        mon_sess = monitored_session._HookedSession(sess, [hook])
676        mon_sess.run(self.train_op)
677        self.assertEqual(1,
678                         checkpoint_utils.load_variable(self.model_dir,
679                                                        self.global_step.name))
680
681  def test_save_steps_saves_periodically(self):
682    with self.graph.as_default():
683      hook = basic_session_run_hooks.CheckpointSaverHook(
684          self.model_dir, save_steps=2, scaffold=self.scaffold)
685      hook.begin()
686      self.scaffold.finalize()
687      with session_lib.Session() as sess:
688        sess.run(self.scaffold.init_op)
689        mon_sess = monitored_session._HookedSession(sess, [hook])
690        mon_sess.run(self.train_op)
691        mon_sess.run(self.train_op)
692        # Not saved
693        self.assertEqual(1,
694                         checkpoint_utils.load_variable(self.model_dir,
695                                                        self.global_step.name))
696        mon_sess.run(self.train_op)
697        # saved
698        self.assertEqual(3,
699                         checkpoint_utils.load_variable(self.model_dir,
700                                                        self.global_step.name))
701        mon_sess.run(self.train_op)
702        # Not saved
703        self.assertEqual(3,
704                         checkpoint_utils.load_variable(self.model_dir,
705                                                        self.global_step.name))
706        mon_sess.run(self.train_op)
707        # saved
708        self.assertEqual(5,
709                         checkpoint_utils.load_variable(self.model_dir,
710                                                        self.global_step.name))
711
712  def test_save_saves_at_end(self):
713    with self.graph.as_default():
714      hook = basic_session_run_hooks.CheckpointSaverHook(
715          self.model_dir, save_secs=2, scaffold=self.scaffold)
716      hook.begin()
717      self.scaffold.finalize()
718      with session_lib.Session() as sess:
719        sess.run(self.scaffold.init_op)
720        mon_sess = monitored_session._HookedSession(sess, [hook])
721        mon_sess.run(self.train_op)
722        mon_sess.run(self.train_op)
723        hook.end(sess)
724        self.assertEqual(2,
725                         checkpoint_utils.load_variable(self.model_dir,
726                                                        self.global_step.name))
727
728  def test_summary_writer_defs(self):
729    fake_summary_writer.FakeSummaryWriter.install()
730    writer_cache.FileWriterCache.clear()
731    summary_writer = writer_cache.FileWriterCache.get(self.model_dir)
732
733    with self.graph.as_default():
734      hook = basic_session_run_hooks.CheckpointSaverHook(
735          self.model_dir, save_steps=2, scaffold=self.scaffold)
736      hook.begin()
737      self.scaffold.finalize()
738      with session_lib.Session() as sess:
739        sess.run(self.scaffold.init_op)
740        mon_sess = monitored_session._HookedSession(sess, [hook])
741        hook.after_create_session(sess, None)
742        mon_sess.run(self.train_op)
743      summary_writer.assert_summaries(
744          test_case=self,
745          expected_logdir=self.model_dir,
746          expected_added_meta_graphs=[
747              meta_graph.create_meta_graph_def(
748                  graph_def=self.graph.as_graph_def(add_shapes=True),
749                  saver_def=self.scaffold.saver.saver_def)
750          ])
751
752    fake_summary_writer.FakeSummaryWriter.uninstall()
753
754  def test_save_checkpoint_before_first_train_step(self):
755    with self.graph.as_default():
756      hook = basic_session_run_hooks.CheckpointSaverHook(
757          self.model_dir, save_steps=2, scaffold=self.scaffold)
758      hook.begin()
759      self.scaffold.finalize()
760      with session_lib.Session() as sess:
761        mon_sess = monitored_session._HookedSession(sess, [hook])
762        sess.run(self.scaffold.init_op)
763        hook.after_create_session(sess, None)
764        # Verifies that checkpoint is saved at step 0.
765        self.assertEqual(0,
766                         checkpoint_utils.load_variable(self.model_dir,
767                                                        self.global_step.name))
768        # Verifies that no checkpoint is saved after one training step.
769        mon_sess.run(self.train_op)
770        self.assertEqual(0,
771                         checkpoint_utils.load_variable(self.model_dir,
772                                                        self.global_step.name))
773        # Verifies that checkpoint is saved after save_steps.
774        mon_sess.run(self.train_op)
775        self.assertEqual(2,
776                         checkpoint_utils.load_variable(self.model_dir,
777                                                        self.global_step.name))
778
779  def test_save_graph_def(self):
780    with self.graph.as_default():
781      hook = basic_session_run_hooks.CheckpointSaverHook(
782          self.model_dir, save_steps=1, scaffold=self.scaffold,
783          save_graph_def=True)
784      hook.begin()
785      self.scaffold.finalize()
786      with session_lib.Session() as sess:
787        sess.run(self.scaffold.init_op)
788        mon_sess = monitored_session._HookedSession(sess, [hook])
789        sess.run(self.scaffold.init_op)
790        hook.after_create_session(sess, None)
791
792        self.assertIn('graph.pbtxt', os.listdir(self.model_dir))
793        # Should have a single .meta file for step 0
794        self.assertLen(gfile.Glob(os.path.join(self.model_dir, '*.meta')), 1)
795
796        mon_sess.run(self.train_op)
797        self.assertLen(gfile.Glob(os.path.join(self.model_dir, '*.meta')), 2)
798
799  def test_save_graph_def_false(self):
800    with self.graph.as_default():
801      hook = basic_session_run_hooks.CheckpointSaverHook(
802          self.model_dir, save_steps=1, scaffold=self.scaffold,
803          save_graph_def=False)
804      hook.begin()
805      self.scaffold.finalize()
806      with session_lib.Session() as sess:
807        sess.run(self.scaffold.init_op)
808        mon_sess = monitored_session._HookedSession(sess, [hook])
809        sess.run(self.scaffold.init_op)
810        hook.after_create_session(sess, None)
811
812        self.assertNotIn('graph.pbtxt', os.listdir(self.model_dir))
813        # Should have a single .meta file for step 0
814        self.assertEmpty(gfile.Glob(os.path.join(self.model_dir, '*.meta')))
815
816        mon_sess.run(self.train_op)
817        self.assertEmpty(gfile.Glob(os.path.join(self.model_dir, '*.meta')))
818
819
820
821
class CheckpointSaverHookMultiStepTest(test.TestCase):
  """Tests CheckpointSaverHook when each run advances several global steps."""

  def setUp(self):
    self.steps_per_run = 5
    self.model_dir = tempfile.mkdtemp()
    self.graph = ops.Graph()
    with self.graph.as_default():
      self.scaffold = monitored_session.Scaffold()
      self.global_step = training_util.get_or_create_global_step()
      # Every run of the train op moves the global step by `steps_per_run`.
      self.train_op = training_util._increment_global_step(self.steps_per_run)

  def tearDown(self):
    shutil.rmtree(self.model_dir, ignore_errors=True)

  def _begin_hook(self):
    """Builds the hook under test, begins it and finalizes the scaffold.

    Must be called while `self.graph` is the default graph.

    Returns:
      The `CheckpointSaverHook` saving every `2 * steps_per_run` steps.
    """
    saver_hook = basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir,
        save_steps=2 * self.steps_per_run,
        scaffold=self.scaffold)
    saver_hook._set_steps_per_run(self.steps_per_run)
    saver_hook.begin()
    self.scaffold.finalize()
    return saver_hook

  def _checkpointed_step(self):
    """Returns the global step value stored in the checkpoint in model_dir."""
    return checkpoint_utils.load_variable(self.model_dir,
                                          self.global_step.name)

  def test_save_steps_saves_in_first_step(self):
    with self.graph.as_default():
      saver_hook = self._begin_hook()
      with session_lib.Session() as raw_sess:
        raw_sess.run(self.scaffold.init_op)
        hooked_sess = monitored_session._HookedSession(raw_sess, [saver_hook])
        hooked_sess.run(self.train_op)
        self.assertEqual(5, self._checkpointed_step())

  def test_save_steps_saves_periodically(self):
    # Checkpointed step expected after each consecutive run: runs reach steps
    # 5, 10, 15, 20, 25 and only every other run is saved.
    expected_after_each_run = (5, 5, 15, 15, 25)
    with self.graph.as_default():
      saver_hook = self._begin_hook()
      with session_lib.Session() as raw_sess:
        raw_sess.run(self.scaffold.init_op)
        hooked_sess = monitored_session._HookedSession(raw_sess, [saver_hook])
        for expected_step in expected_after_each_run:
          hooked_sess.run(self.train_op)
          self.assertEqual(expected_step, self._checkpointed_step())

  def test_save_steps_saves_at_end(self):
    with self.graph.as_default():
      saver_hook = self._begin_hook()
      with session_lib.Session() as raw_sess:
        raw_sess.run(self.scaffold.init_op)
        hooked_sess = monitored_session._HookedSession(raw_sess, [saver_hook])
        for _ in range(2):
          hooked_sess.run(self.train_op)
        saver_hook.end(raw_sess)
        self.assertEqual(10, self._checkpointed_step())
913
914
class ResourceCheckpointSaverHookTest(test.TestCase):
  """Tests CheckpointSaverHook with a global step created under use_resource."""

  def setUp(self):
    self.model_dir = tempfile.mkdtemp()
    self.graph = ops.Graph()
    with self.graph.as_default():
      self.scaffold = monitored_session.Scaffold()
      with variable_scope.variable_scope('foo', use_resource=True):
        self.global_step = training_util.get_or_create_global_step()
      self.train_op = training_util._increment_global_step(1)

  def tearDown(self):
    # Remove the temporary checkpoint directory created in setUp so that test
    # runs do not accumulate directories on disk.
    shutil.rmtree(self.model_dir, ignore_errors=True)

  def test_save_steps_saves_periodically(self):
    """Checks the checkpointed step after successive runs with save_steps=2."""
    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_steps=2, scaffold=self.scaffold)
      hook.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)
        mon_sess.run(self.train_op)
        # Not saved
        self.assertEqual(1,
                         checkpoint_utils.load_variable(self.model_dir,
                                                        self.global_step.name))
        mon_sess.run(self.train_op)
        # saved
        self.assertEqual(3,
                         checkpoint_utils.load_variable(self.model_dir,
                                                        self.global_step.name))
        mon_sess.run(self.train_op)
        # Not saved
        self.assertEqual(3,
                         checkpoint_utils.load_variable(self.model_dir,
                                                        self.global_step.name))
        mon_sess.run(self.train_op)
        # saved
        self.assertEqual(5,
                         checkpoint_utils.load_variable(self.model_dir,
                                                        self.global_step.name))
956
957
class StepCounterHookTest(test.TestCase):
  """Tests for StepCounterHook's steps-per-second summaries and warnings."""

  def setUp(self):
    self.log_dir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.log_dir, ignore_errors=True)

  def _assert_graph_only(self, writer, graph):
    """Runs the fake writer's own check with no expected plain summaries."""
    writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_graph=graph,
        expected_summaries={})

  def _assert_rate_summaries(self, writer, steps):
    """Asserts a positive 'global_step/sec' summary exists at exactly `steps`."""
    self.assertItemsEqual(steps, writer.summaries.keys())
    for step in steps:
      rate = writer.summaries[step][0].value[0]
      self.assertEqual('global_step/sec', rate.tag)
      self.assertGreater(rate.simple_value, 0)

  @test.mock.patch.object(time, 'time')
  def test_step_counter_every_n_steps(self, mock_time):
    mock_time.return_value = MOCK_START_TIME
    with ops.Graph().as_default() as graph, session_lib.Session() as sess:
      training_util.get_or_create_global_step()
      train_op = training_util._increment_global_step(1)
      writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
      counter_hook = basic_session_run_hooks.StepCounterHook(
          summary_writer=writer, every_n_steps=10)
      counter_hook.begin()
      self.evaluate(variables_lib.global_variables_initializer())
      hooked_sess = monitored_session._HookedSession(sess, [counter_hook])
      with test.mock.patch.object(tf_logging, 'warning') as mock_log:
        for _ in range(30):
          # Advance the mocked clock so that measured elapsed time is nonzero.
          mock_time.return_value += 0.01
          hooked_sess.run(train_op)
        # logging.warning should not be called.
        self.assertIsNone(mock_log.call_args)
      counter_hook.end(sess)
      self._assert_graph_only(writer, graph)
      self._assert_rate_summaries(writer, [11, 21])

  @test.mock.patch.object(time, 'time')
  def test_step_counter_every_n_secs(self, mock_time):
    mock_time.return_value = MOCK_START_TIME
    with ops.Graph().as_default() as graph, session_lib.Session() as sess:
      training_util.get_or_create_global_step()
      train_op = training_util._increment_global_step(1)
      writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
      counter_hook = basic_session_run_hooks.StepCounterHook(
          summary_writer=writer, every_n_steps=None, every_n_secs=0.1)

      counter_hook.begin()
      self.evaluate(variables_lib.global_variables_initializer())
      hooked_sess = monitored_session._HookedSession(sess, [counter_hook])
      hooked_sess.run(train_op)
      # Two further runs, each 0.2 mocked seconds after the previous one.
      for _ in range(2):
        mock_time.return_value += 0.2
        hooked_sess.run(train_op)
      counter_hook.end(sess)

      self._assert_graph_only(writer, graph)
      self.assertTrue(writer.summaries, 'No summaries were created.')
      self.assertItemsEqual([2, 3], writer.summaries.keys())
      for recorded in writer.summaries.values():
        rate = recorded[0].value[0]
        self.assertEqual('global_step/sec', rate.tag)
        self.assertGreater(rate.simple_value, 0)

  def test_global_step_name(self):
    with ops.Graph().as_default() as graph, session_lib.Session() as sess:
      # Register a custom variable named 'bar/foo' as the global step.
      with variable_scope.variable_scope('bar'):
        variable_scope.get_variable(
            'foo',
            initializer=0,
            trainable=False,
            collections=[
                ops.GraphKeys.GLOBAL_STEP, ops.GraphKeys.GLOBAL_VARIABLES
            ])
      train_op = training_util._increment_global_step(1)
      writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
      counter_hook = basic_session_run_hooks.StepCounterHook(
          summary_writer=writer, every_n_steps=1, every_n_secs=None)

      counter_hook.begin()
      self.evaluate(variables_lib.global_variables_initializer())
      hooked_sess = monitored_session._HookedSession(sess, [counter_hook])
      for _ in range(2):
        hooked_sess.run(train_op)
      counter_hook.end(sess)

      self._assert_graph_only(writer, graph)
      self.assertTrue(writer.summaries, 'No summaries were created.')
      self.assertItemsEqual([2], writer.summaries.keys())
      rate = writer.summaries[2][0].value[0]
      self.assertEqual('bar/foo/sec', rate.tag)

  def test_log_warning_if_global_step_not_increased(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      training_util.get_or_create_global_step()
      train_op = training_util._increment_global_step(0)  # keep same.
      self.evaluate(variables_lib.global_variables_initializer())
      counter_hook = basic_session_run_hooks.StepCounterHook(
          every_n_steps=1, every_n_secs=None)
      counter_hook.begin()
      hooked_sess = monitored_session._HookedSession(sess, [counter_hook])
      hooked_sess.run(train_op)  # Run one step to record global step.
      with test.mock.patch.object(tf_logging, 'log_first_n') as mock_log:
        for _ in range(30):
          hooked_sess.run(train_op)
        logged_call = str(mock_log.call_args)
        self.assertRegex(logged_call, 'global step.*has not been increased')
      counter_hook.end(sess)

  def _setup_steps_per_run_test(self,
                                every_n_steps,
                                steps_per_run,
                                graph,
                                sess):
    """Creates the train op, fake writer, hook and hooked session on `self`.

    Args:
      every_n_steps: Passed to `StepCounterHook` as `every_n_steps`.
      steps_per_run: Global step increment per run; also set on the hook.
      graph: Graph handed to the fake summary writer.
      sess: Session wrapped by the hooked session.
    """
    training_util.get_or_create_global_step()
    self.train_op = training_util._increment_global_step(steps_per_run)
    self.summary_writer = fake_summary_writer.FakeSummaryWriter(
        self.log_dir, graph)
    counter_hook = basic_session_run_hooks.StepCounterHook(
        summary_writer=self.summary_writer, every_n_steps=every_n_steps)
    self.hook = counter_hook
    counter_hook._set_steps_per_run(steps_per_run)
    counter_hook.begin()
    self.evaluate(variables_lib.global_variables_initializer())
    self.mon_sess = monitored_session._HookedSession(sess, [counter_hook])

  def _check_steps_per_run(self, mock_time, every_n_steps, steps_per_run,
                           logged_steps):
    """Runs five train steps and checks which global steps were logged."""
    mock_time.return_value = MOCK_START_TIME
    with ops.Graph().as_default() as graph, session_lib.Session() as sess:
      self._setup_steps_per_run_test(every_n_steps, steps_per_run, graph, sess)

      for _ in range(5):
        mock_time.return_value += 0.01
        self.mon_sess.run(self.train_op)

      self.hook.end(sess)
      self._assert_graph_only(self.summary_writer, graph)
      self._assert_rate_summaries(self.summary_writer, logged_steps)

  @test.mock.patch.object(time, 'time')
  def test_steps_per_run_less_than_every_n_steps(self, mock_time):
    # Logs at 15, 25
    self._check_steps_per_run(
        mock_time, every_n_steps=10, steps_per_run=5, logged_steps=[15, 25])

  @test.mock.patch.object(time, 'time')
  def test_steps_per_run_equal_every_n_steps(self, mock_time):
    # Logs at 10, 15, 20, 25
    self._check_steps_per_run(
        mock_time,
        every_n_steps=5,
        steps_per_run=5,
        logged_steps=[10, 15, 20, 25])

  @test.mock.patch.object(time, 'time')
  def test_steps_per_run_greater_than_every_n_steps(self, mock_time):
    # Logs at 20, 30, 40, 50
    self._check_steps_per_run(
        mock_time,
        every_n_steps=5,
        steps_per_run=10,
        logged_steps=[20, 30, 40, 50])
1163
1164
@test_util.run_deprecated_v1
class SummarySaverHookTest(test.TestCase):
  """Tests for SummarySaverHook with step-based and time-based saving."""

  def setUp(self):
    test.TestCase.setUp(self)

    self.log_dir = 'log/dir'
    self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)

    # `bumped` adds one to the variable each time it is evaluated; the tests
    # below expect the summary value to grow by one per saved summary.
    counter_var = variables_lib.Variable(0.0)
    bumped = state_ops.assign_add(counter_var, 1.0)
    doubled = bumped * 2
    self.summary_op = summary_lib.scalar('my_summary', bumped)
    self.summary_op2 = summary_lib.scalar('my_summary2', doubled)

    training_util.get_or_create_global_step()
    self.train_op = training_util._increment_global_step(1)

  def _drive_hook(self, hook, num_steps, after_each_step=None):
    """Begins `hook`, runs the train op `num_steps` times, then ends the hook.

    Args:
      hook: The hook to attach to the session.
      num_steps: Number of times to run `self.train_op`.
      after_each_step: Optional zero-argument callable invoked after every run.
    """
    with self.cached_session() as sess:
      hook.begin()
      self.evaluate(variables_lib.global_variables_initializer())
      hooked_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(num_steps):
        hooked_sess.run(self.train_op)
        if after_each_step is not None:
          after_each_step()
      hook.end(sess)

  def _assert_my_summary_at(self, steps):
    """Asserts 'my_summary' equals 1.0, 2.0, ... at the given steps in order."""
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            step: {
                'my_summary': float(ordinal)
            } for ordinal, step in enumerate(steps, start=1)
        })

  def test_raise_when_scaffold_and_summary_op_both_missing(self):
    self.assertRaises(ValueError, basic_session_run_hooks.SummarySaverHook)

  def test_raise_when_scaffold_and_summary_op_both_present(self):
    with self.assertRaises(ValueError):
      basic_session_run_hooks.SummarySaverHook(
          scaffold=monitored_session.Scaffold(), summary_op=self.summary_op)

  def test_raise_in_both_secs_and_steps(self):
    with self.assertRaises(ValueError):
      basic_session_run_hooks.SummarySaverHook(
          save_secs=10, save_steps=20, summary_writer=self.summary_writer)

  def test_raise_in_none_secs_and_steps(self):
    with self.assertRaises(ValueError):
      basic_session_run_hooks.SummarySaverHook(
          save_secs=None, save_steps=None, summary_writer=self.summary_writer)

  def test_save_steps(self):
    saver_hook = basic_session_run_hooks.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    self._drive_hook(saver_hook, num_steps=30)

    self._assert_my_summary_at([1, 9, 17, 25])

  def test_multiple_summaries(self):
    saver_hook = basic_session_run_hooks.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=[self.summary_op, self.summary_op2])

    self._drive_hook(saver_hook, num_steps=10)

    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            1: {
                'my_summary': 1.0,
                'my_summary2': 2.0
            },
            9: {
                'my_summary': 2.0,
                'my_summary2': 4.0
            },
        })

  @test.mock.patch.object(time, 'time')
  def test_save_secs_saving_once_every_step(self, mock_time):
    mock_time.return_value = MOCK_START_TIME
    saver_hook = basic_session_run_hooks.SummarySaverHook(
        save_secs=0.5,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    def advance_clock():
      mock_time.return_value += 0.5

    self._drive_hook(saver_hook, num_steps=4, after_each_step=advance_clock)

    self._assert_my_summary_at([1, 2, 3, 4])

  @test.mock.patch.object(time, 'time')
  def test_save_secs_saving_once_every_three_steps(self, mock_time):
    mock_time.return_value = 1484695987.209386
    saver_hook = basic_session_run_hooks.SummarySaverHook(
        save_secs=9.,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    def advance_clock():
      mock_time.return_value += 3.1

    self._drive_hook(saver_hook, num_steps=8, after_each_step=advance_clock)

    # 24.8 seconds passed (3.1*8), it saves every 9 seconds starting from first:
    self._assert_my_summary_at([1, 4, 7])
1329
1330
class GlobalStepWaiterHookTest(test.TestCase):
  """Tests for GlobalStepWaiterHook."""

  def test_not_wait_for_step_zero(self):
    with ops.Graph().as_default():
      training_util.get_or_create_global_step()
      waiter = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=0)
      waiter.begin()
      with session_lib.Session() as sess:
        run_context = session_run_hook.SessionRunContext(
            original_args=None, session=sess)
        # Before run should return without waiting gstep increment.
        waiter.before_run(run_context)

  @test.mock.patch.object(time, 'sleep')
  def test_wait_for_step(self, mock_sleep):
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      waiter = basic_session_run_hooks.GlobalStepWaiterHook(
          wait_until_step=1000)
      waiter.begin()

      with session_lib.Session() as sess:
        # Mock out calls to time.sleep() to update the global step.
        sleep_calls = []
        # Global step to assign on the n-th call to sleep(): the first call
        # moves it from 0 to 500, the second from 500 to 1100.
        step_after_call = {1: 500, 2: 1100}

        def fake_sleep(seconds):
          del seconds  # argument is ignored
          sleep_calls.append(None)
          call_number = len(sleep_calls)
          if call_number not in step_after_call:
            raise AssertionError(
                'Expected before_run() to terminate after the second call to '
                'time.sleep()')
          sess.run(state_ops.assign(gstep, step_after_call[call_number]))

        mock_sleep.side_effect = fake_sleep

        # Run the mocked-out interaction with the hook.
        self.evaluate(variables_lib.global_variables_initializer())
        run_context = session_run_hook.SessionRunContext(
            original_args=None, session=sess)
        waiter.before_run(run_context)
        self.assertEqual(len(sleep_calls), 2)
1381
1382
class FinalOpsHookTest(test.TestCase):
  """Tests for the values FinalOpsHook exposes after end() is called."""

  def test_final_ops_is_scalar_tensor(self):
    with ops.Graph().as_default():
      want = 4
      final_hook = basic_session_run_hooks.FinalOpsHook(
          constant_op.constant(want))
      final_hook.begin()

      with session_lib.Session() as session:
        final_hook.end(session)
        self.assertEqual(want, final_hook.final_ops_values)

  def test_final_ops_is_tensor(self):
    with ops.Graph().as_default():
      want = [1, 6, 3, 5, 2, 4]
      final_hook = basic_session_run_hooks.FinalOpsHook(
          constant_op.constant(want))
      final_hook.begin()

      with session_lib.Session() as session:
        final_hook.end(session)
        got = final_hook.final_ops_values.tolist()
        self.assertListEqual(want, got)

  def test_final_ops_triggers_out_of_range_error(self):
    with ops.Graph().as_default():
      # A one-element dataset: the first read succeeds, so the read performed
      # by the hook's end() is the one expected to hit the end of the sequence.
      single_item = dataset_ops.Dataset.range(1)
      next_element = dataset_ops.make_one_shot_iterator(single_item).get_next()

      final_hook = basic_session_run_hooks.FinalOpsHook(next_element)
      final_hook.begin()

      with session_lib.Session() as session:
        session.run(next_element)
        with test.mock.patch.object(tf_logging, 'warning') as mock_log:
          with self.assertRaisesRegex(errors.OutOfRangeError,
                                      'End of sequence'):
            final_hook.end(session)
          logged_call = str(mock_log.call_args)
          self.assertRegex(logged_call, 'dependency back to some input source')

  def test_final_ops_with_dictionary(self):
    with ops.Graph().as_default():
      want = [4, -3]
      fed_tensor = array_ops.placeholder(dtype=dtypes.float32)
      feeds = {fed_tensor: want}

      final_hook = basic_session_run_hooks.FinalOpsHook(fed_tensor, feeds)
      final_hook.begin()

      with session_lib.Session() as session:
        final_hook.end(session)
        got = final_hook.final_ops_values.tolist()
        self.assertListEqual(want, got)
1444
1445
@test_util.run_deprecated_v1
class ResourceSummarySaverHookTest(test.TestCase):
  """Tests SummarySaverHook with variables created under use_resource=True."""

  def setUp(self):
    test.TestCase.setUp(self)

    self.log_dir = 'log/dir'
    self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)

    # The summarized tensor adds one to the variable each time it is evaluated.
    counter_var = variable_scope.get_variable(
        'var', initializer=0.0, use_resource=True)
    bumped = state_ops.assign_add(counter_var, 1.0)
    self.summary_op = summary_lib.scalar('my_summary', bumped)

    with variable_scope.variable_scope('foo', use_resource=True):
      training_util.create_global_step()
    self.train_op = training_util._increment_global_step(1)

  def test_save_steps(self):
    saver_hook = basic_session_run_hooks.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.cached_session() as sess:
      saver_hook.begin()
      self.evaluate(variables_lib.global_variables_initializer())
      hooked_sess = monitored_session._HookedSession(sess, [saver_hook])
      for _ in range(30):
        hooked_sess.run(self.train_op)
      saver_hook.end(sess)

    # Summaries are expected at steps 1, 9, 17 and 25 with values 1.0 to 4.0.
    saved_steps = [1, 9, 17, 25]
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            step: {
                'my_summary': float(ordinal)
            } for ordinal, step in enumerate(saved_steps, start=1)
        })
1494
1495
class FeedFnHookTest(test.TestCase):
  """Tests that FeedFnHook supplies feeds produced by its feed_fn."""

  def test_feeding_placeholder(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      x = array_ops.placeholder(dtype=dtypes.float32)
      y = x + 1

      def feed_one():
        # Feed dict supplied to the hook for the placeholder.
        return {x: 1.0}

      feed_hook = basic_session_run_hooks.FeedFnHook(feed_fn=feed_one)
      feed_hook.begin()
      hooked_sess = monitored_session._HookedSession(sess, [feed_hook])
      result = hooked_sess.run(y)
      self.assertEqual(result, 2)
1507
1508
class ProfilerHookTest(test.TestCase):
  """Tests for `basic_session_run_hooks.ProfilerHook`.

  Each test runs against a fresh graph holding a global step and a train op
  that increments it, and a fresh temporary output directory in which
  timeline files matching `timeline-*.json` are counted.
  """

  def setUp(self):
    """Creates the temp output dir, the graph, and the global-step train op."""
    super(ProfilerHookTest, self).setUp()
    self.output_dir = tempfile.mkdtemp()
    self.graph = ops.Graph()
    # Glob pattern used by `_count_timeline_files` to find written timelines.
    self.filepattern = os.path.join(self.output_dir, 'timeline-*.json')
    with self.graph.as_default():
      self.global_step = training_util.get_or_create_global_step()
      self.train_op = state_ops.assign_add(self.global_step, 1)

  def tearDown(self):
    """Removes the temporary output directory created in `setUp`."""
    super(ProfilerHookTest, self).tearDown()
    shutil.rmtree(self.output_dir, ignore_errors=True)

  def _count_timeline_files(self):
    """Returns the number of files in the output dir matching the pattern."""
    return len(gfile.Glob(self.filepattern))

  @test_util.run_deprecated_v1
  def test_raise_in_both_secs_and_steps(self):
    """Passing both `save_secs` and `save_steps` raises ValueError."""
    with self.assertRaises(ValueError):
      basic_session_run_hooks.ProfilerHook(save_secs=10, save_steps=20)

  @test_util.run_deprecated_v1
  def test_raise_in_none_secs_and_steps(self):
    """Passing neither `save_secs` nor `save_steps` raises ValueError."""
    with self.assertRaises(ValueError):
      basic_session_run_hooks.ProfilerHook(save_secs=None, save_steps=None)

  def test_save_secs_does_not_save_in_first_step(self):
    """With `save_secs`, the first run produces no timeline file."""
    with self.graph.as_default():
      hook = basic_session_run_hooks.ProfilerHook(
          save_secs=2, output_dir=self.output_dir)
      with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
        sess.run(self.train_op)
        self.assertEqual(0, self._count_timeline_files())

  @test.mock.patch.object(time, 'time')
  def test_save_secs_saves_periodically(self, mock_time):
    """With `save_secs=2`, a timeline is written once 2 mocked seconds pass.

    Args:
      mock_time: Mock replacing `time.time`, used to control the clock.
    """
    with self.graph.as_default():
      # Pick a fixed start time.
      mock_time.return_value = MOCK_START_TIME
      hook = basic_session_run_hooks.ProfilerHook(
          save_secs=2, output_dir=self.output_dir)
      with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
        sess.run(self.train_op)  # Not saved.
        self.assertEqual(0, self._count_timeline_files())
        # Simulate 2.5 seconds of sleep.
        mock_time.return_value = MOCK_START_TIME + 2.5
        sess.run(self.train_op)  # Saved.
        self.assertEqual(1, self._count_timeline_files())

        # Pretend some small amount of time has passed.
        mock_time.return_value = MOCK_START_TIME + 2.6
        sess.run(self.train_op)  # Not saved.
        # Edge test just before we should save the timeline (1.9s after the
        # save at +2.5).
        mock_time.return_value = MOCK_START_TIME + 4.4
        sess.run(self.train_op)  # Not saved.
        self.assertEqual(1, self._count_timeline_files())

        # Exactly 2 seconds after the previous save.
        mock_time.return_value = MOCK_START_TIME + 4.5
        sess.run(self.train_op)  # Saved.
        self.assertEqual(2, self._count_timeline_files())

  def test_save_steps_does_not_save_in_first_step(self):
    """With `save_steps`, the first run produces no timeline file."""
    with self.graph.as_default():
      hook = basic_session_run_hooks.ProfilerHook(
          save_steps=1, output_dir=self.output_dir)
      with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
        sess.run(self.train_op)  # Not saved.
        self.assertEqual(0, self._count_timeline_files())

  def test_save_steps_saves_periodically(self):
    """With `save_steps=2`, every second run adds one timeline file."""
    with self.graph.as_default():
      hook = basic_session_run_hooks.ProfilerHook(
          save_steps=2, output_dir=self.output_dir)
      with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
        self.assertEqual(0, self._count_timeline_files())
        sess.run(self.train_op)  # Not saved.
        self.assertEqual(0, self._count_timeline_files())
        sess.run(self.train_op)  # Saved.
        self.assertEqual(1, self._count_timeline_files())
        sess.run(self.train_op)  # Not saved.
        self.assertEqual(1, self._count_timeline_files())
        sess.run(self.train_op)  # Saved.
        self.assertEqual(2, self._count_timeline_files())
        sess.run(self.train_op)  # Not saved.
        self.assertEqual(2, self._count_timeline_files())

  def test_run_metadata_saves(self):
    """Run metadata is added to the summary writer under the key 'step_2'."""
    writer_cache.FileWriterCache.clear()
    fake_summary_writer.FakeSummaryWriter.install()
    # Always undo the install, even if the run or the assertion below raises;
    # otherwise the fake writer would stay installed for subsequent tests.
    try:
      fake_writer = writer_cache.FileWriterCache.get(self.output_dir)
      with self.graph.as_default():
        hook = basic_session_run_hooks.ProfilerHook(
            save_steps=1, output_dir=self.output_dir)
        with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
          sess.run(self.train_op)  # Not saved.
          sess.run(self.train_op)  # Saved.
          self.assertEqual(
              list(fake_writer._added_run_metadata.keys()), ['step_2'])
    finally:
      fake_summary_writer.FakeSummaryWriter.uninstall()
1610
1611
# Hand control to the TensorFlow test entry point when run as a script.
if __name__ == '__main__':
  test.main()
1614