• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for supervisor.py."""
16
17import glob
18import os
19import shutil
20import time
21import uuid
22
23
24from tensorflow.core.framework import graph_pb2
25from tensorflow.core.protobuf import config_pb2
26from tensorflow.core.protobuf import meta_graph_pb2
27from tensorflow.core.util import event_pb2
28from tensorflow.python.checkpoint import checkpoint_management
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import errors_impl
32from tensorflow.python.framework import meta_graph
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import test_util
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import io_ops
37from tensorflow.python.ops import parsing_ops
38from tensorflow.python.ops import variables
39from tensorflow.python.platform import gfile
40from tensorflow.python.platform import test
41from tensorflow.python.summary import summary
42from tensorflow.python.summary import summary_iterator
43from tensorflow.python.summary.writer import writer
44from tensorflow.python.training import input as input_lib
45from tensorflow.python.training import saver as saver_lib
46from tensorflow.python.training import server_lib
47from tensorflow.python.training import session_manager as session_manager_lib
48from tensorflow.python.training import supervisor
49
50
51def _summary_iterator(test_dir):
52  """Reads events from test_dir/events.
53
54  Args:
55    test_dir: Name of the test directory.
56
57  Returns:
58    A summary_iterator
59  """
60  event_paths = sorted(glob.glob(os.path.join(test_dir, "event*")))
61  return summary_iterator.summary_iterator(event_paths[-1])
62
63
64class SupervisorTest(test.TestCase):
65
66  def _test_dir(self, test_name):
67    test_dir = os.path.join(self.get_temp_dir(), test_name)
68    if os.path.exists(test_dir):
69      shutil.rmtree(test_dir)
70    return test_dir
71
72  def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
73    """Wait for a checkpoint file to appear.
74
75    Args:
76      pattern: A string.
77      timeout_secs: How long to wait for in seconds.
78      for_checkpoint: whether we're globbing for checkpoints.
79    """
80    end_time = time.time() + timeout_secs
81    while time.time() < end_time:
82      if for_checkpoint:
83        if checkpoint_management.checkpoint_exists(pattern):
84          return
85      else:
86        if len(gfile.Glob(pattern)) >= 1:
87          return
88      time.sleep(0.05)
89    self.assertFalse(True, "Glob never matched any file: %s" % pattern)
90
91  # This test does not test much.
92  def testBasics(self):
93    logdir = self._test_dir("basics")
94    with ops.Graph().as_default():
95      my_op = constant_op.constant(1.0)
96      sv = supervisor.Supervisor(logdir=logdir)
97      sess = sv.prepare_or_wait_for_session("")
98      for _ in range(10):
99        self.evaluate(my_op)
100      sess.close()
101      sv.stop()
102
103  def testManagedSession(self):
104    logdir = self._test_dir("managed_session")
105    with ops.Graph().as_default():
106      my_op = constant_op.constant(1.0)
107      sv = supervisor.Supervisor(logdir=logdir)
108      with sv.managed_session(""):
109        for _ in range(10):
110          self.evaluate(my_op)
111      # Supervisor has been stopped.
112      self.assertTrue(sv.should_stop())
113
114  def testManagedSessionUserError(self):
115    logdir = self._test_dir("managed_user_error")
116    with ops.Graph().as_default():
117      my_op = constant_op.constant(1.0)
118      sv = supervisor.Supervisor(logdir=logdir)
119      last_step = None
120      with self.assertRaisesRegex(RuntimeError, "failing here"):
121        with sv.managed_session("") as sess:
122          for step in range(10):
123            last_step = step
124            if step == 1:
125              raise RuntimeError("failing here")
126            else:
127              self.evaluate(my_op)
128      # Supervisor has been stopped.
129      self.assertTrue(sv.should_stop())
130      self.assertEqual(1, last_step)
131
132  def testManagedSessionIgnoreOutOfRangeError(self):
133    logdir = self._test_dir("managed_out_of_range")
134    with ops.Graph().as_default():
135      my_op = constant_op.constant(1.0)
136      sv = supervisor.Supervisor(logdir=logdir)
137      last_step = None
138      with sv.managed_session("") as sess:
139        for step in range(10):
140          last_step = step
141          if step == 3:
142            raise errors_impl.OutOfRangeError(my_op.op.node_def, my_op.op,
143                                              "all done")
144          else:
145            self.evaluate(my_op)
146      # Supervisor has been stopped.  OutOfRangeError was not thrown.
147      self.assertTrue(sv.should_stop())
148      self.assertEqual(3, last_step)
149
150  def testManagedSessionDoNotKeepSummaryWriter(self):
151    logdir = self._test_dir("managed_not_keep_summary_writer")
152    with ops.Graph().as_default():
153      summary.scalar("c1", constant_op.constant(1))
154      summary.scalar("c2", constant_op.constant(2))
155      summary.scalar("c3", constant_op.constant(3))
156      summ = summary.merge_all()
157      sv = supervisor.Supervisor(logdir=logdir, summary_op=None)
158      with sv.managed_session(
159          "", close_summary_writer=True, start_standard_services=False) as sess:
160        sv.summary_computed(sess, sess.run(summ))
161      # Sleep 1.2s to make sure that the next event file has a different name
162      # than the current one.
163      time.sleep(1.2)
164      with sv.managed_session(
165          "", close_summary_writer=True, start_standard_services=False) as sess:
166        sv.summary_computed(sess, sess.run(summ))
167    event_paths = sorted(glob.glob(os.path.join(logdir, "event*")))
168    self.assertEqual(2, len(event_paths))
169    # The two event files should have the same contents.
170    for path in event_paths:
171      # The summary iterator should report the summary once as we closed the
172      # summary writer across the 2 sessions.
173      rr = summary_iterator.summary_iterator(path)
174      # The first event should list the file_version.
175      ev = next(rr)
176      self.assertEqual("brain.Event:2", ev.file_version)
177
178      # The next one has the graph and metagraph.
179      ev = next(rr)
180      self.assertTrue(ev.graph_def)
181
182      ev = next(rr)
183      self.assertTrue(ev.meta_graph_def)
184
185      # The next one should have the values from the summary.
186      # But only once.
187      ev = next(rr)
188      self.assertProtoEquals("""
189        value { tag: 'c1' simple_value: 1.0 }
190        value { tag: 'c2' simple_value: 2.0 }
191        value { tag: 'c3' simple_value: 3.0 }
192        """, ev.summary)
193
194      # The next one should be a stop message if we closed cleanly.
195      ev = next(rr)
196      self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
197
198      # We should be done.
199      with self.assertRaises(StopIteration):
200        next(rr)
201
202  def testManagedSessionKeepSummaryWriter(self):
203    logdir = self._test_dir("managed_keep_summary_writer")
204    with ops.Graph().as_default():
205      summary.scalar("c1", constant_op.constant(1))
206      summary.scalar("c2", constant_op.constant(2))
207      summary.scalar("c3", constant_op.constant(3))
208      summ = summary.merge_all()
209      sv = supervisor.Supervisor(logdir=logdir)
210      with sv.managed_session(
211          "", close_summary_writer=False,
212          start_standard_services=False) as sess:
213        sv.summary_computed(sess, sess.run(summ))
214      with sv.managed_session(
215          "", close_summary_writer=False,
216          start_standard_services=False) as sess:
217        sv.summary_computed(sess, sess.run(summ))
218    # Now close the summary writer to flush the events.
219    sv.summary_writer.close()
220    # The summary iterator should report the summary twice as we reused
221    # the same summary writer across the 2 sessions.
222    rr = _summary_iterator(logdir)
223    # The first event should list the file_version.
224    ev = next(rr)
225    self.assertEqual("brain.Event:2", ev.file_version)
226
227    # The next one has the graph.
228    ev = next(rr)
229    self.assertTrue(ev.graph_def)
230
231    ev = next(rr)
232    self.assertTrue(ev.meta_graph_def)
233
234    # The next one should have the values from the summary.
235    ev = next(rr)
236    self.assertProtoEquals("""
237      value { tag: 'c1' simple_value: 1.0 }
238      value { tag: 'c2' simple_value: 2.0 }
239      value { tag: 'c3' simple_value: 3.0 }
240      """, ev.summary)
241
242    # The next one should also have the values from the summary.
243    ev = next(rr)
244    self.assertProtoEquals("""
245      value { tag: 'c1' simple_value: 1.0 }
246      value { tag: 'c2' simple_value: 2.0 }
247      value { tag: 'c3' simple_value: 3.0 }
248      """, ev.summary)
249
250    # We should be done.
251    self.assertRaises(StopIteration, lambda: next(rr))
252
253  def _csv_data(self, logdir):
254    # Create a small data file with 3 CSV records.
255    data_path = os.path.join(logdir, "data.csv")
256    with open(data_path, "w") as f:
257      f.write("1,2,3\n")
258      f.write("4,5,6\n")
259      f.write("7,8,9\n")
260    return data_path
261
262  def testManagedEndOfInputOneQueue(self):
263    # Tests that the supervisor finishes without an error when using
264    # a fixed number of epochs, reading from a single queue.
265    logdir = self._test_dir("managed_end_of_input_one_queue")
266    os.makedirs(logdir)
267    data_path = self._csv_data(logdir)
268    with ops.Graph().as_default():
269      # Create an input pipeline that reads the file 3 times.
270      filename_queue = input_lib.string_input_producer(
271          [data_path], num_epochs=3)
272      reader = io_ops.TextLineReader()
273      _, csv = reader.read(filename_queue)
274      rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
275      sv = supervisor.Supervisor(logdir=logdir)
276      with sv.managed_session("") as sess:
277        while not sv.should_stop():
278          sess.run(rec)
279
280  def testManagedEndOfInputTwoQueues(self):
281    # Tests that the supervisor finishes without an error when using
282    # a fixed number of epochs, reading from two queues, the second
283    # one producing a batch from the first one.
284    logdir = self._test_dir("managed_end_of_input_two_queues")
285    os.makedirs(logdir)
286    data_path = self._csv_data(logdir)
287    with ops.Graph().as_default():
288      # Create an input pipeline that reads the file 3 times.
289      filename_queue = input_lib.string_input_producer(
290          [data_path], num_epochs=3)
291      reader = io_ops.TextLineReader()
292      _, csv = reader.read(filename_queue)
293      rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
294      shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
295      sv = supervisor.Supervisor(logdir=logdir)
296      with sv.managed_session("") as sess:
297        while not sv.should_stop():
298          sess.run(shuff_rec)
299
300  def testManagedMainErrorTwoQueues(self):
301    # Tests that the supervisor correctly raises a main loop
302    # error even when using multiple queues for input.
303    logdir = self._test_dir("managed_main_error_two_queues")
304    os.makedirs(logdir)
305    data_path = self._csv_data(logdir)
306    with self.assertRaisesRegex(RuntimeError, "fail at step 3"):
307      with ops.Graph().as_default():
308        # Create an input pipeline that reads the file 3 times.
309        filename_queue = input_lib.string_input_producer(
310            [data_path], num_epochs=3)
311        reader = io_ops.TextLineReader()
312        _, csv = reader.read(filename_queue)
313        rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
314        shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
315        sv = supervisor.Supervisor(logdir=logdir)
316        with sv.managed_session("") as sess:
317          for step in range(9):
318            if sv.should_stop():
319              break
320            elif step == 3:
321              raise RuntimeError("fail at step 3")
322            else:
323              sess.run(shuff_rec)
324
325  def testSessionConfig(self):
326    logdir = self._test_dir("session_config")
327    with ops.Graph().as_default():
328      with ops.device("/cpu:1"):
329        my_op = constant_op.constant([1.0])
330      sv = supervisor.Supervisor(logdir=logdir)
331      sess = sv.prepare_or_wait_for_session(
332          "", config=config_pb2.ConfigProto(device_count={"CPU": 2}))
333      for _ in range(10):
334        self.evaluate(my_op)
335      sess.close()
336      sv.stop()
337
338  def testChiefCanWriteEvents(self):
339    logdir = self._test_dir("can_write")
340    with ops.Graph().as_default():
341      summary.scalar("c1", constant_op.constant(1))
342      summary.scalar("c2", constant_op.constant(2))
343      summary.scalar("c3", constant_op.constant(3))
344      summ = summary.merge_all()
345      sv = supervisor.Supervisor(is_chief=True, logdir=logdir, summary_op=None)
346      meta_graph_def = meta_graph.create_meta_graph_def()
347      sess = sv.prepare_or_wait_for_session("")
348      sv.summary_computed(sess, sess.run(summ))
349      sess.close()
350      # Wait to make sure everything is written to file before stopping.
351      time.sleep(1)
352      sv.stop()
353
354    rr = _summary_iterator(logdir)
355
356    # The first event should list the file_version.
357    ev = next(rr)
358    self.assertEqual("brain.Event:2", ev.file_version)
359
360    # The next one has the graph.
361    ev = next(rr)
362    ev_graph = graph_pb2.GraphDef()
363    ev_graph.ParseFromString(ev.graph_def)
364    self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
365
366    # Stored MetaGraphDef
367    ev = next(rr)
368    ev_meta_graph = meta_graph_pb2.MetaGraphDef()
369    ev_meta_graph.ParseFromString(ev.meta_graph_def)
370    self.assertProtoEquals(meta_graph_def, ev_meta_graph)
371    self.assertProtoEquals(
372        sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
373    # The next one should have the values from the summary.
374    ev = next(rr)
375    self.assertProtoEquals("""
376      value { tag: 'c1' simple_value: 1.0 }
377      value { tag: 'c2' simple_value: 2.0 }
378      value { tag: 'c3' simple_value: 3.0 }
379      """, ev.summary)
380
381    # The next one should be a stop message if we closed cleanly.
382    ev = next(rr)
383    self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
384
385    # We should be done.
386    self.assertRaises(StopIteration, lambda: next(rr))
387
388  def testNonChiefCannotWriteEvents(self):
389
390    def _summary_computed():
391      with ops.Graph().as_default():
392        sv = supervisor.Supervisor(is_chief=False)
393        sess = sv.prepare_or_wait_for_session("")
394        summary.scalar("c1", constant_op.constant(1))
395        summary.scalar("c2", constant_op.constant(2))
396        summ = summary.merge_all()
397        sv.summary_computed(sess, sess.run(summ))
398
399    def _start_standard_services():
400      with ops.Graph().as_default():
401        sv = supervisor.Supervisor(is_chief=False)
402        sess = sv.prepare_or_wait_for_session("")
403        sv.start_standard_services(sess)
404
405    self.assertRaises(RuntimeError, _summary_computed)
406    self.assertRaises(RuntimeError, _start_standard_services)
407
408  def testNoLogdirButWantSummary(self):
409    with ops.Graph().as_default():
410      summary.scalar("c1", constant_op.constant(1))
411      summary.scalar("c2", constant_op.constant(2))
412      summary.scalar("c3", constant_op.constant(3))
413      summ = summary.merge_all()
414      sv = supervisor.Supervisor(logdir="", summary_op=None)
415      sess = sv.prepare_or_wait_for_session("")
416      with self.assertRaisesRegex(RuntimeError, "requires a summary writer"):
417        sv.summary_computed(sess, sess.run(summ))
418
419  @test_util.run_v1_only("train.Supervisor is for v1 only")
420  def testLogdirButExplicitlyNoSummaryWriter(self):
421    logdir = self._test_dir("explicit_no_summary_writer")
422    with ops.Graph().as_default():
423      variables.VariableV1([1.0], name="foo")
424      summary.scalar("c1", constant_op.constant(1))
425      summary.scalar("c2", constant_op.constant(2))
426      summary.scalar("c3", constant_op.constant(3))
427      summ = summary.merge_all()
428      sv = supervisor.Supervisor(logdir=logdir, summary_writer=None)
429      sess = sv.prepare_or_wait_for_session("")
430      # Check that a checkpoint is still be generated.
431      self._wait_for_glob(sv.save_path, 3.0)
432      # Check that we cannot write a summary
433      with self.assertRaisesRegex(RuntimeError, "requires a summary writer"):
434        sv.summary_computed(sess, sess.run(summ))
435
436  def testNoLogdirButExplicitSummaryWriter(self):
437    logdir = self._test_dir("explicit_summary_writer")
438    with ops.Graph().as_default():
439      summary.scalar("c1", constant_op.constant(1))
440      summary.scalar("c2", constant_op.constant(2))
441      summary.scalar("c3", constant_op.constant(3))
442      summ = summary.merge_all()
443      sw = writer.FileWriter(logdir)
444      sv = supervisor.Supervisor(logdir="", summary_op=None, summary_writer=sw)
445      meta_graph_def = meta_graph.create_meta_graph_def()
446      sess = sv.prepare_or_wait_for_session("")
447      sv.summary_computed(sess, sess.run(summ))
448      sess.close()
449      # Wait to make sure everything is written to file before stopping.
450      time.sleep(1)
451      sv.stop()
452
453    # Check the summary was written to 'logdir'
454    rr = _summary_iterator(logdir)
455
456    # The first event should list the file_version.
457    ev = next(rr)
458    self.assertEqual("brain.Event:2", ev.file_version)
459
460    # The next one has the graph.
461    ev = next(rr)
462    ev_graph = graph_pb2.GraphDef()
463    ev_graph.ParseFromString(ev.graph_def)
464    self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
465
466    # Stored MetaGraphDef
467    ev = next(rr)
468    ev_meta_graph = meta_graph_pb2.MetaGraphDef()
469    ev_meta_graph.ParseFromString(ev.meta_graph_def)
470    self.assertProtoEquals(meta_graph_def, ev_meta_graph)
471    self.assertProtoEquals(
472        sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
473
474    # The next one should have the values from the summary.
475    ev = next(rr)
476    self.assertProtoEquals("""
477      value { tag: 'c1' simple_value: 1.0 }
478      value { tag: 'c2' simple_value: 2.0 }
479      value { tag: 'c3' simple_value: 3.0 }
480      """, ev.summary)
481
482    # The next one should be a stop message if we closed cleanly.
483    ev = next(rr)
484    self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
485
486    # We should be done.
487    self.assertRaises(StopIteration, lambda: next(rr))
488
489  def testNoLogdirSucceeds(self):
490    with ops.Graph().as_default():
491      variables.VariableV1([1.0, 2.0, 3.0])
492      sv = supervisor.Supervisor(logdir="", summary_op=None)
493      sess = sv.prepare_or_wait_for_session("")
494      sess.close()
495      sv.stop()
496
497  def testUseSessionManager(self):
498    with ops.Graph().as_default():
499      variables.VariableV1([1.0, 2.0, 3.0])
500      sm = session_manager_lib.SessionManager()
501      # Pass in session_manager. The additional init_op is ignored.
502      sv = supervisor.Supervisor(logdir="", session_manager=sm)
503      sv.prepare_or_wait_for_session("")
504
505  @test_util.run_v1_only("train.Supervisor is for v1 only")
506  def testInitOp(self):
507    logdir = self._test_dir("default_init_op")
508    with ops.Graph().as_default():
509      v = variables.VariableV1([1.0, 2.0, 3.0])
510      sv = supervisor.Supervisor(logdir=logdir)
511      sess = sv.prepare_or_wait_for_session("")
512      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
513      sv.stop()
514
515  @test_util.run_v1_only("train.Supervisor is for v1 only")
516  def testInitFn(self):
517    logdir = self._test_dir("default_init_op")
518    with ops.Graph().as_default():
519      v = variables.VariableV1([1.0, 2.0, 3.0])
520
521      def _init_fn(sess):
522        sess.run(v.initializer)
523
524      sv = supervisor.Supervisor(logdir=logdir, init_op=None, init_fn=_init_fn)
525      sess = sv.prepare_or_wait_for_session("")
526      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
527      sv.stop()
528
529  @test_util.run_v1_only("train.Supervisor is for v1 only")
530  def testInitOpWithFeedDict(self):
531    logdir = self._test_dir("feed_dict_init_op")
532    with ops.Graph().as_default():
533      p = array_ops.placeholder(dtypes.float32, shape=(3,))
534      v = variables.VariableV1(p, name="v")
535      sv = supervisor.Supervisor(
536          logdir=logdir,
537          init_op=variables.global_variables_initializer(),
538          init_feed_dict={p: [1.0, 2.0, 3.0]})
539      sess = sv.prepare_or_wait_for_session("")
540      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
541      sv.stop()
542
543  @test_util.run_v1_only("train.Supervisor is for v1 only")
544  def testReadyForLocalInitOp(self):
545    server = server_lib.Server.create_local_server()
546    logdir = self._test_dir("default_ready_for_local_init_op")
547
548    uid = uuid.uuid4().hex
549
550    def get_session(is_chief):
551      g = ops.Graph()
552      with g.as_default():
553        with ops.device("/job:localhost"):
554          v = variables.VariableV1(
555              1, name="default_ready_for_local_init_op_v_" + str(uid))
556          vadd = v.assign_add(1)
557          w = variables.VariableV1(
558              v,
559              trainable=False,
560              collections=[ops.GraphKeys.LOCAL_VARIABLES],
561              name="default_ready_for_local_init_op_w_" + str(uid))
562          ready_for_local_init_op = variables.report_uninitialized_variables(
563              variables.global_variables())
564      sv = supervisor.Supervisor(
565          logdir=logdir,
566          is_chief=is_chief,
567          graph=g,
568          recovery_wait_secs=1,
569          init_op=v.initializer,
570          ready_for_local_init_op=ready_for_local_init_op)
571      sess = sv.prepare_or_wait_for_session(server.target)
572
573      return sv, sess, v, vadd, w
574
575    sv0, sess0, v0, _, w0 = get_session(True)
576    sv1, sess1, _, vadd1, w1 = get_session(False)
577
578    self.assertEqual(1, sess0.run(w0))
579    self.assertEqual(2, sess1.run(vadd1))
580    self.assertEqual(1, sess1.run(w1))
581    self.assertEqual(2, sess0.run(v0))
582
583    sv0.stop()
584    sv1.stop()
585
586  @test_util.run_v1_only("train.Supervisor is for v1 only")
587  def testReadyForLocalInitOpRestoreFromCheckpoint(self):
588    server = server_lib.Server.create_local_server()
589    logdir = self._test_dir("ready_for_local_init_op_restore")
590
591    uid = uuid.uuid4().hex
592
593    # Create a checkpoint.
594    with ops.Graph().as_default():
595      v = variables.VariableV1(
596          10.0, name="ready_for_local_init_op_restore_v_" + str(uid))
597      summary.scalar("ready_for_local_init_op_restore_v_" + str(uid), v)
598      sv = supervisor.Supervisor(logdir=logdir)
599      sv.prepare_or_wait_for_session(server.target)
600      save_path = sv.save_path
601      self._wait_for_glob(save_path, 3.0)
602      self._wait_for_glob(
603          os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
604      # Wait to make sure everything is written to file before stopping.
605      time.sleep(1)
606      sv.stop()
607
608    def get_session(is_chief):
609      g = ops.Graph()
610      with g.as_default():
611        with ops.device("/job:localhost"):
612          v = variables.VariableV1(
613              1.0, name="ready_for_local_init_op_restore_v_" + str(uid))
614          vadd = v.assign_add(1)
615          w = variables.VariableV1(
616              v,
617              trainable=False,
618              collections=[ops.GraphKeys.LOCAL_VARIABLES],
619              name="ready_for_local_init_op_restore_w_" + str(uid))
620          ready_for_local_init_op = variables.report_uninitialized_variables(
621              variables.global_variables())
622      sv = supervisor.Supervisor(
623          logdir=logdir,
624          is_chief=is_chief,
625          graph=g,
626          recovery_wait_secs=1,
627          ready_for_local_init_op=ready_for_local_init_op)
628      sess = sv.prepare_or_wait_for_session(server.target)
629
630      return sv, sess, v, vadd, w
631
632    sv0, sess0, v0, _, w0 = get_session(True)
633    sv1, sess1, _, vadd1, w1 = get_session(False)
634
635    self.assertEqual(10, sess0.run(w0))
636    self.assertEqual(11, sess1.run(vadd1))
637    self.assertEqual(10, sess1.run(w1))
638    self.assertEqual(11, sess0.run(v0))
639
640    sv0.stop()
641    sv1.stop()
642
643  def testLocalInitOp(self):
644    logdir = self._test_dir("default_local_init_op")
645    with ops.Graph().as_default():
646      # A local variable.
647      v = variables.VariableV1(
648          [1.0, 2.0, 3.0],
649          trainable=False,
650          collections=[ops.GraphKeys.LOCAL_VARIABLES])
651
652      # An entity which is initialized through a TABLE_INITIALIZER.
653      w = variables.VariableV1([4, 5, 6], trainable=False, collections=[])
654      ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, w.initializer)
655
656      # This shouldn't add a variable to the VARIABLES collection responsible
657      # for variables that are saved/restored from checkpoints.
658      self.assertEqual(len(variables.global_variables()), 0)
659
660      # Suppress normal variable inits to make sure the local one is
661      # initialized via local_init_op.
662      sv = supervisor.Supervisor(logdir=logdir, init_op=None)
663      sess = sv.prepare_or_wait_for_session("")
664      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
665      self.assertAllClose([4, 5, 6], sess.run(w))
666      sv.stop()
667
668  def testLocalInitOpForNonChief(self):
669    logdir = self._test_dir("default_local_init_op_non_chief")
670    with ops.Graph().as_default():
671      with ops.device("/job:localhost"):
672        # A local variable.
673        v = variables.VariableV1(
674            [1.0, 2.0, 3.0],
675            trainable=False,
676            collections=[ops.GraphKeys.LOCAL_VARIABLES])
677        # This shouldn't add a variable to the VARIABLES collection responsible
678        # for variables that are saved/restored from checkpoints.
679        self.assertEqual(len(variables.global_variables()), 0)
680
681      # Suppress normal variable inits to make sure the local one is
682      # initialized via local_init_op.
683      sv = supervisor.Supervisor(logdir=logdir, init_op=None, is_chief=False)
684      sess = sv.prepare_or_wait_for_session("")
685      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
686      sv.stop()
687
688  def testInitOpFails(self):
689    server = server_lib.Server.create_local_server()
690    logdir = self._test_dir("default_init_op_fails")
691    with ops.Graph().as_default():
692      v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
693      variables.VariableV1([4.0, 5.0, 6.0], name="w")
694      # w will not be initialized.
695      sv = supervisor.Supervisor(logdir=logdir, init_op=v.initializer)
696      with self.assertRaisesRegex(RuntimeError, "Variables not initialized: w"):
697        sv.prepare_or_wait_for_session(server.target)
698
699  def testInitOpFailsForTransientVariable(self):
700    server = server_lib.Server.create_local_server()
701    logdir = self._test_dir("default_init_op_fails_for_local_variable")
702    with ops.Graph().as_default():
703      v = variables.VariableV1(
704          [1.0, 2.0, 3.0],
705          name="v",
706          collections=[ops.GraphKeys.LOCAL_VARIABLES])
707      variables.VariableV1(
708          [1.0, 2.0, 3.0],
709          name="w",
710          collections=[ops.GraphKeys.LOCAL_VARIABLES])
711      # w will not be initialized.
712      sv = supervisor.Supervisor(logdir=logdir, local_init_op=v.initializer)
713      with self.assertRaisesRegex(RuntimeError, "Variables not initialized: w"):
714        sv.prepare_or_wait_for_session(server.target)
715
716  @test_util.run_v1_only("train.Supervisor is for v1 only")
717  def testSetupFail(self):
718    logdir = self._test_dir("setup_fail")
719    with ops.Graph().as_default():
720      variables.VariableV1([1.0, 2.0, 3.0], name="v")
721      with self.assertRaisesRegex(ValueError, "must have their device set"):
722        supervisor.Supervisor(logdir=logdir, is_chief=False)
723    with ops.Graph().as_default(), ops.device("/job:ps"):
724      variables.VariableV1([1.0, 2.0, 3.0], name="v")
725      supervisor.Supervisor(logdir=logdir, is_chief=False)
726
727  @test_util.run_v1_only("train.Supervisor is for v1 only")
728  def testDefaultGlobalStep(self):
729    logdir = self._test_dir("default_global_step")
730    with ops.Graph().as_default():
731      variables.VariableV1(287, name="global_step")
732      sv = supervisor.Supervisor(logdir=logdir)
733      sess = sv.prepare_or_wait_for_session("")
734      self.assertEqual(287, sess.run(sv.global_step))
735      sv.stop()
736
737  @test_util.run_v1_only("train.Supervisor is for v1 only")
738  def testRestoreFromMetaGraph(self):
739    logdir = self._test_dir("restore_from_meta_graph")
740    with ops.Graph().as_default():
741      variables.VariableV1(1, name="v0")
742      sv = supervisor.Supervisor(logdir=logdir)
743      sess = sv.prepare_or_wait_for_session("")
744      filename = sv.saver.save(sess, sv.save_path)
745      sv.stop()
746    # Create a new Graph and Supervisor and recover.
747    with ops.Graph().as_default():
748      new_saver = saver_lib.import_meta_graph(".".join([filename, "meta"]))
749      self.assertIsNotNone(new_saver)
750      sv2 = supervisor.Supervisor(logdir=logdir, saver=new_saver)
751      sess = sv2.prepare_or_wait_for_session("")
752      self.assertEqual(1, sess.run("v0:0"))
753      sv2.saver.save(sess, sv2.save_path)
754      sv2.stop()
755
756  # This test is based on the fact that the standard services start
757  # right away and get to run once before sv.stop() returns.
758  # We still sleep a bit to make the test robust.
759  @test_util.run_v1_only("train.Supervisor is for v1 only")
760  def testStandardServicesWithoutGlobalStep(self):
761    logdir = self._test_dir("standard_services_without_global_step")
762    # Create a checkpoint.
763    with ops.Graph().as_default():
764      v = variables.VariableV1([1.0], name="foo")
765      summary.scalar("v", v[0])
766      sv = supervisor.Supervisor(logdir=logdir)
767      meta_graph_def = meta_graph.create_meta_graph_def(
768          saver_def=sv.saver.saver_def)
769      sess = sv.prepare_or_wait_for_session("")
770      save_path = sv.save_path
771      self._wait_for_glob(save_path, 3.0)
772      self._wait_for_glob(
773          os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
774      # Wait to make sure everything is written to file before stopping.
775      time.sleep(1)
776      sv.stop()
777    # There should be an event file with a version number.
778    rr = _summary_iterator(logdir)
779    ev = next(rr)
780    self.assertEqual("brain.Event:2", ev.file_version)
781    ev = next(rr)
782    ev_graph = graph_pb2.GraphDef()
783    ev_graph.ParseFromString(ev.graph_def)
784    self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
785
786    # Stored MetaGraphDef
787    ev = next(rr)
788    ev_meta_graph = meta_graph_pb2.MetaGraphDef()
789    ev_meta_graph.ParseFromString(ev.meta_graph_def)
790    self.assertProtoEquals(meta_graph_def, ev_meta_graph)
791    self.assertProtoEquals(
792        sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
793
794    ev = next(rr)
795    self.assertProtoEquals("value { tag: 'v' simple_value: 1.0 }", ev.summary)
796
797    ev = next(rr)
798    self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
799
800    self.assertRaises(StopIteration, lambda: next(rr))
801    # There should be a checkpoint file with the variable "foo"
802    with ops.Graph().as_default(), self.cached_session() as sess:
803      v = variables.VariableV1([10.10], name="foo")
804      sav = saver_lib.Saver([v])
805      sav.restore(sess, save_path)
806      self.assertEqual(1.0, self.evaluate(v)[0])
807
  # Same as testStandardServicesWithoutGlobalStep but with a global step.
  # We should also get a 'global_step/sec' summary.
810  @test_util.run_v1_only("train.Supervisor is for v1 only")
811  def testStandardServicesWithGlobalStep(self):
812    logdir = self._test_dir("standard_services_with_global_step")
813    # Create a checkpoint.
814    with ops.Graph().as_default():
815      v = variables.VariableV1([123], name="global_step")
816      sv = supervisor.Supervisor(logdir=logdir)
817      meta_graph_def = meta_graph.create_meta_graph_def(
818          saver_def=sv.saver.saver_def)
819      sess = sv.prepare_or_wait_for_session("")
820      # This is where the checkpoint will appear, with step number 123.
821      save_path = "%s-123" % sv.save_path
822      self._wait_for_glob(save_path, 3.0)
823      self._wait_for_glob(
824          os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
825      # Wait to make sure everything is written to file before stopping.
826      time.sleep(1)
827      sv.stop()
828    # There should be an event file with a version number.
829    rr = _summary_iterator(logdir)
830    ev = next(rr)
831    self.assertEqual("brain.Event:2", ev.file_version)
832    ev = next(rr)
833    ev_graph = graph_pb2.GraphDef()
834    ev_graph.ParseFromString(ev.graph_def)
835    self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
836    ev = next(rr)
837    ev_meta_graph = meta_graph_pb2.MetaGraphDef()
838    ev_meta_graph.ParseFromString(ev.meta_graph_def)
839    self.assertProtoEquals(meta_graph_def, ev_meta_graph)
840    self.assertProtoEquals(
841        sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
842    ev = next(rr)
843    # It is actually undeterministic whether SessionLog.START gets written
844    # before the summary or the checkpoint, but this works when run 10000 times.
845    self.assertEqual(123, ev.step)
846    self.assertEqual(event_pb2.SessionLog.START, ev.session_log.status)
847    first = next(rr)
848    second = next(rr)
849    # It is undeterministic whether the value gets written before the checkpoint
850    # since they are on separate threads, so we check for both conditions.
851    if first.HasField("summary"):
852      self.assertProtoEquals("""value { tag: 'global_step/sec'
853                                        simple_value: 0.0 }""", first.summary)
854      self.assertEqual(123, second.step)
855      self.assertEqual(event_pb2.SessionLog.CHECKPOINT,
856                       second.session_log.status)
857    else:
858      self.assertEqual(123, first.step)
859      self.assertEqual(event_pb2.SessionLog.CHECKPOINT,
860                       first.session_log.status)
861      self.assertProtoEquals("""value { tag: 'global_step/sec'
862                                        simple_value: 0.0 }""", second.summary)
863    ev = next(rr)
864    self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
865    self.assertRaises(StopIteration, lambda: next(rr))
866    # There should be a checkpoint file with the variable "foo"
867    with ops.Graph().as_default(), self.cached_session() as sess:
868      v = variables.VariableV1([-12], name="global_step")
869      sav = saver_lib.Saver([v])
870      sav.restore(sess, save_path)
871      self.assertEqual(123, self.evaluate(v)[0])
872
873  def testNoQueueRunners(self):
874    with ops.Graph().as_default(), self.cached_session() as sess:
875      sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners"))
876      self.assertEqual(0, len(sv.start_queue_runners(sess)))
877      sv.stop()
878
879  def testPrepareSessionAfterStopForChief(self):
880    logdir = self._test_dir("prepare_after_stop_chief")
881    with ops.Graph().as_default():
882      sv = supervisor.Supervisor(logdir=logdir, is_chief=True)
883
884      # Create a first session and then stop.
885      sess = sv.prepare_or_wait_for_session("")
886      sv.stop()
887      sess.close()
888      self.assertTrue(sv.should_stop())
889
890      # Now create a second session and test that we don't stay stopped, until
891      # we ask to stop again.
892      sess2 = sv.prepare_or_wait_for_session("")
893      self.assertFalse(sv.should_stop())
894      sv.stop()
895      sess2.close()
896      self.assertTrue(sv.should_stop())
897
898  def testPrepareSessionAfterStopForNonChief(self):
899    logdir = self._test_dir("prepare_after_stop_nonchief")
900    with ops.Graph().as_default():
901      sv = supervisor.Supervisor(logdir=logdir, is_chief=False)
902
903      # Create a first session and then stop.
904      sess = sv.prepare_or_wait_for_session("")
905      sv.stop()
906      sess.close()
907      self.assertTrue(sv.should_stop())
908
909      # Now create a second session and test that we don't stay stopped, until
910      # we ask to stop again.
911      sess2 = sv.prepare_or_wait_for_session("")
912      self.assertFalse(sv.should_stop())
913      sv.stop()
914      sess2.close()
915      self.assertTrue(sv.should_stop())
916
917
# Run this module's tests through the TensorFlow platform test runner when the
# file is executed directly as a script.
if __name__ == "__main__":
  test.main()
920