# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for supervisor.py."""

import glob
import os
import shutil
import time
import uuid

from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_manager as session_manager_lib
from tensorflow.python.training import supervisor


def _summary_iterator(test_dir):
  """Reads events from the last event file in `test_dir`.

  Args:
    test_dir: Name of the test directory.

  Returns:
    A summary_iterator over the file matching `test_dir/event*` that sorts
    last by name.
  """
  event_paths = sorted(glob.glob(os.path.join(test_dir, "event*")))
  return summary_iterator.summary_iterator(event_paths[-1])


class SupervisorTest(test.TestCase):

  def _test_dir(self, test_name):
    """Returns a path under the temp dir, removing anything already there.

    The directory itself is not created.

    Args:
      test_name: Name of the sub-directory to use.

    Returns:
      The path of the (non-existent) test directory.
    """
    test_dir = os.path.join(self.get_temp_dir(), test_name)
    if os.path.exists(test_dir):
      shutil.rmtree(test_dir)
    return test_dir

  def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
    """Waits for a checkpoint, or a file matching a glob, to appear.

    Polls every 50ms and fails the test if nothing matched before the timeout.

    Args:
      pattern: A string. Passed to `checkpoint_exists` when `for_checkpoint`
        is True, and to `gfile.Glob` otherwise.
      timeout_secs: How long to wait for in seconds.
      for_checkpoint: whether we're globbing for checkpoints.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout_secs
    while time.monotonic() < deadline:
      if for_checkpoint:
        found = checkpoint_management.checkpoint_exists(pattern)
      else:
        found = bool(gfile.Glob(pattern))
      if found:
        return
      time.sleep(0.05)
    self.fail("Glob never matched any file: %s" % pattern)

  def _scalar_summaries_op(self):
    """Adds the scalar summaries c1=1, c2=2, c3=3 to the default graph.

    Returns:
      The result of `summary.merge_all()` called after adding the summaries.
    """
    summary.scalar("c1", constant_op.constant(1))
    summary.scalar("c2", constant_op.constant(2))
    summary.scalar("c3", constant_op.constant(3))
    return summary.merge_all()

  def _assert_scalar_summaries(self, event):
    """Asserts that `event.summary` holds exactly the c1/c2/c3 values."""
    self.assertProtoEquals("""
      value { tag: 'c1' simple_value: 1.0 }
      value { tag: 'c2' simple_value: 2.0 }
      value { tag: 'c3' simple_value: 3.0 }
      """, event.summary)

  def _assert_graph_events(self, events, graph, meta_graph_def):
    """Consumes and checks the version, graph and metagraph events.

    Args:
      events: An event iterator positioned at the start of an event file.
      graph: The graph whose GraphDef (with shapes) is expected in the file.
      meta_graph_def: The MetaGraphDef expected in the file.
    """
    # The first event should list the file_version.
    ev = next(events)
    self.assertEqual("brain.Event:2", ev.file_version)

    expected_graph_def = graph.as_graph_def(add_shapes=True)

    # The next one has the graph.
    ev = next(events)
    ev_graph = graph_pb2.GraphDef()
    ev_graph.ParseFromString(ev.graph_def)
    self.assertProtoEquals(expected_graph_def, ev_graph)

    # Then comes the stored MetaGraphDef.
    ev = next(events)
    ev_meta_graph = meta_graph_pb2.MetaGraphDef()
    ev_meta_graph.ParseFromString(ev.meta_graph_def)
    self.assertProtoEquals(meta_graph_def, ev_meta_graph)
    self.assertProtoEquals(expected_graph_def, ev_meta_graph.graph_def)

  def _assert_no_more_events(self, events):
    """Asserts that the event iterator `events` is exhausted."""
    with self.assertRaises(StopIteration):
      next(events)

  # This test does not test much.
  def testBasics(self):
    logdir = self._test_dir("basics")
    with ops.Graph().as_default():
      my_op = constant_op.constant(1.0)
      sv = supervisor.Supervisor(logdir=logdir)
      sess = sv.prepare_or_wait_for_session("")
      for _ in range(10):
        self.evaluate(my_op)
      sess.close()
      sv.stop()

  def testManagedSession(self):
    logdir = self._test_dir("managed_session")
    with ops.Graph().as_default():
      my_op = constant_op.constant(1.0)
      sv = supervisor.Supervisor(logdir=logdir)
      with sv.managed_session(""):
        for _ in range(10):
          self.evaluate(my_op)
      # Supervisor has been stopped.
      self.assertTrue(sv.should_stop())

  def testManagedSessionUserError(self):
    # A user error raised inside managed_session must propagate to the caller.
    logdir = self._test_dir("managed_user_error")
    with ops.Graph().as_default():
      my_op = constant_op.constant(1.0)
      sv = supervisor.Supervisor(logdir=logdir)
      last_step = None
      with self.assertRaisesRegex(RuntimeError, "failing here"):
        with sv.managed_session(""):
          for step in range(10):
            last_step = step
            if step == 1:
              raise RuntimeError("failing here")
            else:
              self.evaluate(my_op)
      # Supervisor has been stopped.
      self.assertTrue(sv.should_stop())
      self.assertEqual(1, last_step)

  def testManagedSessionIgnoreOutOfRangeError(self):
    logdir = self._test_dir("managed_out_of_range")
    with ops.Graph().as_default():
      my_op = constant_op.constant(1.0)
      sv = supervisor.Supervisor(logdir=logdir)
      last_step = None
      with sv.managed_session(""):
        for step in range(10):
          last_step = step
          if step == 3:
            raise errors_impl.OutOfRangeError(my_op.op.node_def, my_op.op,
                                              "all done")
          else:
            self.evaluate(my_op)
      # Supervisor has been stopped.  OutOfRangeError was not thrown.
      self.assertTrue(sv.should_stop())
      self.assertEqual(3, last_step)

  def testManagedSessionDoNotKeepSummaryWriter(self):
    logdir = self._test_dir("managed_not_keep_summary_writer")
    with ops.Graph().as_default():
      summ = self._scalar_summaries_op()
      sv = supervisor.Supervisor(logdir=logdir, summary_op=None)
      with sv.managed_session(
          "", close_summary_writer=True, start_standard_services=False) as sess:
        sv.summary_computed(sess, sess.run(summ))
      # Sleep 1.2s to make sure that the next event file has a different name
      # than the current one.
      time.sleep(1.2)
      with sv.managed_session(
          "", close_summary_writer=True, start_standard_services=False) as sess:
        sv.summary_computed(sess, sess.run(summ))
    event_paths = sorted(glob.glob(os.path.join(logdir, "event*")))
    self.assertEqual(2, len(event_paths))
    # Each of the two event files is checked against the same expectations.
    for path in event_paths:
      # The summary iterator should report the summary once as we closed the
      # summary writer across the 2 sessions.
      rr = summary_iterator.summary_iterator(path)
      # The first event should list the file_version.
      ev = next(rr)
      self.assertEqual("brain.Event:2", ev.file_version)

      # The next ones have the graph and metagraph.
      ev = next(rr)
      self.assertTrue(ev.graph_def)

      ev = next(rr)
      self.assertTrue(ev.meta_graph_def)

      # The next one should have the values from the summary.
      # But only once.
      self._assert_scalar_summaries(next(rr))

      # The next one should be a stop message if we closed cleanly.
      ev = next(rr)
      self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)

      # We should be done.
      self._assert_no_more_events(rr)

  def testManagedSessionKeepSummaryWriter(self):
    logdir = self._test_dir("managed_keep_summary_writer")
    with ops.Graph().as_default():
      summ = self._scalar_summaries_op()
      sv = supervisor.Supervisor(logdir=logdir)
      with sv.managed_session(
          "", close_summary_writer=False,
          start_standard_services=False) as sess:
        sv.summary_computed(sess, sess.run(summ))
      with sv.managed_session(
          "", close_summary_writer=False,
          start_standard_services=False) as sess:
        sv.summary_computed(sess, sess.run(summ))
      # Now close the summary writer to flush the events.
      sv.summary_writer.close()
      # The summary iterator should report the summary twice as we reused
      # the same summary writer across the 2 sessions.
      rr = _summary_iterator(logdir)
      # The first event should list the file_version.
      ev = next(rr)
      self.assertEqual("brain.Event:2", ev.file_version)

      # The next one has the graph.
      ev = next(rr)
      self.assertTrue(ev.graph_def)

      ev = next(rr)
      self.assertTrue(ev.meta_graph_def)

      # The next one should have the values from the summary.
      self._assert_scalar_summaries(next(rr))

      # The next one should also have the values from the summary.
      self._assert_scalar_summaries(next(rr))

      # We should be done.
      self._assert_no_more_events(rr)

  def _csv_data(self, logdir):
    """Writes a small data file with 3 CSV records into `logdir`.

    Args:
      logdir: An existing directory in which to create `data.csv`.

    Returns:
      The path of the file that was written.
    """
    data_path = os.path.join(logdir, "data.csv")
    with open(data_path, "w") as f:
      f.write("1,2,3\n")
      f.write("4,5,6\n")
      f.write("7,8,9\n")
    return data_path

  def testManagedEndOfInputOneQueue(self):
    # Tests that the supervisor finishes without an error when using
    # a fixed number of epochs, reading from a single queue.
    logdir = self._test_dir("managed_end_of_input_one_queue")
    os.makedirs(logdir)
    data_path = self._csv_data(logdir)
    with ops.Graph().as_default():
      # Create an input pipeline that reads the file 3 times.
      filename_queue = input_lib.string_input_producer(
          [data_path], num_epochs=3)
      reader = io_ops.TextLineReader()
      _, csv = reader.read(filename_queue)
      rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
      sv = supervisor.Supervisor(logdir=logdir)
      with sv.managed_session("") as sess:
        while not sv.should_stop():
          sess.run(rec)

  def testManagedEndOfInputTwoQueues(self):
    # Tests that the supervisor finishes without an error when using
    # a fixed number of epochs, reading from two queues, the second
    # one producing a batch from the first one.
    logdir = self._test_dir("managed_end_of_input_two_queues")
    os.makedirs(logdir)
    data_path = self._csv_data(logdir)
    with ops.Graph().as_default():
      # Create an input pipeline that reads the file 3 times.
      filename_queue = input_lib.string_input_producer(
          [data_path], num_epochs=3)
      reader = io_ops.TextLineReader()
      _, csv = reader.read(filename_queue)
      rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
      shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
      sv = supervisor.Supervisor(logdir=logdir)
      with sv.managed_session("") as sess:
        while not sv.should_stop():
          sess.run(shuff_rec)

  def testManagedMainErrorTwoQueues(self):
    # Tests that the supervisor correctly raises a main loop
    # error even when using multiple queues for input.
    logdir = self._test_dir("managed_main_error_two_queues")
    os.makedirs(logdir)
    data_path = self._csv_data(logdir)
    with self.assertRaisesRegex(RuntimeError, "fail at step 3"):
      with ops.Graph().as_default():
        # Create an input pipeline that reads the file 3 times.
        filename_queue = input_lib.string_input_producer(
            [data_path], num_epochs=3)
        reader = io_ops.TextLineReader()
        _, csv = reader.read(filename_queue)
        rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
        shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
        sv = supervisor.Supervisor(logdir=logdir)
        with sv.managed_session("") as sess:
          for step in range(9):
            if sv.should_stop():
              break
            elif step == 3:
              raise RuntimeError("fail at step 3")
            else:
              sess.run(shuff_rec)

  def testSessionConfig(self):
    logdir = self._test_dir("session_config")
    with ops.Graph().as_default():
      # The op is placed on a second CPU device, which only exists because of
      # the device_count passed in the session config below.
      with ops.device("/cpu:1"):
        my_op = constant_op.constant([1.0])
      sv = supervisor.Supervisor(logdir=logdir)
      sess = sv.prepare_or_wait_for_session(
          "", config=config_pb2.ConfigProto(device_count={"CPU": 2}))
      for _ in range(10):
        self.evaluate(my_op)
      sess.close()
      sv.stop()

  def testChiefCanWriteEvents(self):
    logdir = self._test_dir("can_write")
    with ops.Graph().as_default():
      summ = self._scalar_summaries_op()
      sv = supervisor.Supervisor(is_chief=True, logdir=logdir, summary_op=None)
      meta_graph_def = meta_graph.create_meta_graph_def()
      sess = sv.prepare_or_wait_for_session("")
      sv.summary_computed(sess, sess.run(summ))
      sess.close()
      # Wait to make sure everything is written to file before stopping.
      time.sleep(1)
      sv.stop()

    rr = _summary_iterator(logdir)

    # File version, graph and metagraph come first.
    self._assert_graph_events(rr, sess.graph, meta_graph_def)

    # The next one should have the values from the summary.
    self._assert_scalar_summaries(next(rr))

    # The next one should be a stop message if we closed cleanly.
    ev = next(rr)
    self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)

    # We should be done.
    self._assert_no_more_events(rr)

  def testNonChiefCannotWriteEvents(self):

    def _summary_computed():
      with ops.Graph().as_default():
        sv = supervisor.Supervisor(is_chief=False)
        sess = sv.prepare_or_wait_for_session("")
        summary.scalar("c1", constant_op.constant(1))
        summary.scalar("c2", constant_op.constant(2))
        summ = summary.merge_all()
        sv.summary_computed(sess, sess.run(summ))

    def _start_standard_services():
      with ops.Graph().as_default():
        sv = supervisor.Supervisor(is_chief=False)
        sess = sv.prepare_or_wait_for_session("")
        sv.start_standard_services(sess)

    self.assertRaises(RuntimeError, _summary_computed)
    self.assertRaises(RuntimeError, _start_standard_services)

  def testNoLogdirButWantSummary(self):
    with ops.Graph().as_default():
      summ = self._scalar_summaries_op()
      sv = supervisor.Supervisor(logdir="", summary_op=None)
      sess = sv.prepare_or_wait_for_session("")
      with self.assertRaisesRegex(RuntimeError, "requires a summary writer"):
        sv.summary_computed(sess, sess.run(summ))

  @test_util.run_v1_only("train.Supervisor is for v1 only")
  def testLogdirButExplicitlyNoSummaryWriter(self):
    logdir = self._test_dir("explicit_no_summary_writer")
    with ops.Graph().as_default():
      variables.VariableV1([1.0], name="foo")
      summ = self._scalar_summaries_op()
      sv = supervisor.Supervisor(logdir=logdir, summary_writer=None)
      sess = sv.prepare_or_wait_for_session("")
      # Check that a checkpoint is still generated.
      self._wait_for_glob(sv.save_path, 3.0)
      # Check that we cannot write a summary
      with self.assertRaisesRegex(RuntimeError, "requires a summary writer"):
        sv.summary_computed(sess, sess.run(summ))

  def testNoLogdirButExplicitSummaryWriter(self):
    logdir = self._test_dir("explicit_summary_writer")
    with ops.Graph().as_default():
      summ = self._scalar_summaries_op()
      sw = writer.FileWriter(logdir)
      sv = supervisor.Supervisor(logdir="", summary_op=None, summary_writer=sw)
      meta_graph_def = meta_graph.create_meta_graph_def()
      sess = sv.prepare_or_wait_for_session("")
      sv.summary_computed(sess, sess.run(summ))
      sess.close()
      # Wait to make sure everything is written to file before stopping.
      time.sleep(1)
      sv.stop()

    # Check the summary was written to 'logdir'
    rr = _summary_iterator(logdir)

    # File version, graph and metagraph come first.
    self._assert_graph_events(rr, sess.graph, meta_graph_def)

    # The next one should have the values from the summary.
    self._assert_scalar_summaries(next(rr))

    # The next one should be a stop message if we closed cleanly.
    ev = next(rr)
    self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)

    # We should be done.
    self._assert_no_more_events(rr)

  def testNoLogdirSucceeds(self):
    with ops.Graph().as_default():
      variables.VariableV1([1.0, 2.0, 3.0])
      sv = supervisor.Supervisor(logdir="", summary_op=None)
      sess = sv.prepare_or_wait_for_session("")
      sess.close()
      sv.stop()

  def testUseSessionManager(self):
    with ops.Graph().as_default():
      variables.VariableV1([1.0, 2.0, 3.0])
      sm = session_manager_lib.SessionManager()
      # Pass in session_manager. The additional init_op is ignored.
      sv = supervisor.Supervisor(logdir="", session_manager=sm)
      sv.prepare_or_wait_for_session("")

  @test_util.run_v1_only("train.Supervisor is for v1 only")
  def testInitOp(self):
    logdir = self._test_dir("default_init_op")
    with ops.Graph().as_default():
      v = variables.VariableV1([1.0, 2.0, 3.0])
      sv = supervisor.Supervisor(logdir=logdir)
      sess = sv.prepare_or_wait_for_session("")
      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
      sv.stop()

  @test_util.run_v1_only("train.Supervisor is for v1 only")
  def testInitFn(self):
    logdir = self._test_dir("default_init_fn")
    with ops.Graph().as_default():
      v = variables.VariableV1([1.0, 2.0, 3.0])

      def _init_fn(sess):
        sess.run(v.initializer)

      # With init_op disabled, the variable can only get its value from init_fn.
      sv = supervisor.Supervisor(logdir=logdir, init_op=None, init_fn=_init_fn)
      sess = sv.prepare_or_wait_for_session("")
      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
      sv.stop()

  @test_util.run_v1_only("train.Supervisor is for v1 only")
  def testInitOpWithFeedDict(self):
    logdir = self._test_dir("feed_dict_init_op")
    with ops.Graph().as_default():
      p = array_ops.placeholder(dtypes.float32, shape=(3,))
      v = variables.VariableV1(p, name="v")
      sv = supervisor.Supervisor(
          logdir=logdir,
          init_op=variables.global_variables_initializer(),
          init_feed_dict={p: [1.0, 2.0, 3.0]})
      sess = sv.prepare_or_wait_for_session("")
      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
      sv.stop()

  @test_util.run_v1_only("train.Supervisor is for v1 only")
  def testReadyForLocalInitOp(self):
    server = server_lib.Server.create_local_server()
    logdir = self._test_dir("default_ready_for_local_init_op")

    # Random suffix for the variable names created below.
    uid = uuid.uuid4().hex

    def get_session(is_chief):
      g = ops.Graph()
      with g.as_default():
        with ops.device("/job:localhost"):
          v = variables.VariableV1(
              1, name="default_ready_for_local_init_op_v_" + uid)
          vadd = v.assign_add(1)
          # A local variable whose initial value is read from the global `v`.
          w = variables.VariableV1(
              v,
              trainable=False,
              collections=[ops.GraphKeys.LOCAL_VARIABLES],
              name="default_ready_for_local_init_op_w_" + uid)
          ready_for_local_init_op = variables.report_uninitialized_variables(
              variables.global_variables())
      sv = supervisor.Supervisor(
          logdir=logdir,
          is_chief=is_chief,
          graph=g,
          recovery_wait_secs=1,
          init_op=v.initializer,
          ready_for_local_init_op=ready_for_local_init_op)
      sess = sv.prepare_or_wait_for_session(server.target)

      return sv, sess, v, vadd, w

    sv0, sess0, v0, _, w0 = get_session(True)
    sv1, sess1, _, vadd1, w1 = get_session(False)

    self.assertEqual(1, sess0.run(w0))
    self.assertEqual(2, sess1.run(vadd1))
    self.assertEqual(1, sess1.run(w1))
    self.assertEqual(2, sess0.run(v0))

    sv0.stop()
    sv1.stop()

  @test_util.run_v1_only("train.Supervisor is for v1 only")
  def testReadyForLocalInitOpRestoreFromCheckpoint(self):
    server = server_lib.Server.create_local_server()
    logdir = self._test_dir("ready_for_local_init_op_restore")

    # Random suffix for the variable names created below.
    uid = uuid.uuid4().hex

    # Create a checkpoint.
    with ops.Graph().as_default():
      v = variables.VariableV1(
          10.0, name="ready_for_local_init_op_restore_v_" + uid)
      summary.scalar("ready_for_local_init_op_restore_v_" + uid, v)
      sv = supervisor.Supervisor(logdir=logdir)
      sv.prepare_or_wait_for_session(server.target)
      save_path = sv.save_path
      self._wait_for_glob(save_path, 3.0)
      self._wait_for_glob(
          os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
      # Wait to make sure everything is written to file before stopping.
      time.sleep(1)
      sv.stop()

    def get_session(is_chief):
      g = ops.Graph()
      with g.as_default():
        with ops.device("/job:localhost"):
          # The initial value (1.0) differs from the checkpointed one (10.0),
          # so the assertions below can tell which of the two was used.
          v = variables.VariableV1(
              1.0, name="ready_for_local_init_op_restore_v_" + uid)
          vadd = v.assign_add(1)
          w = variables.VariableV1(
              v,
              trainable=False,
              collections=[ops.GraphKeys.LOCAL_VARIABLES],
              name="ready_for_local_init_op_restore_w_" + uid)
          ready_for_local_init_op = variables.report_uninitialized_variables(
              variables.global_variables())
      sv = supervisor.Supervisor(
          logdir=logdir,
          is_chief=is_chief,
          graph=g,
          recovery_wait_secs=1,
          ready_for_local_init_op=ready_for_local_init_op)
      sess = sv.prepare_or_wait_for_session(server.target)

      return sv, sess, v, vadd, w

    sv0, sess0, v0, _, w0 = get_session(True)
    sv1, sess1, _, vadd1, w1 = get_session(False)

    self.assertEqual(10, sess0.run(w0))
    self.assertEqual(11, sess1.run(vadd1))
    self.assertEqual(10, sess1.run(w1))
    self.assertEqual(11, sess0.run(v0))

    sv0.stop()
    sv1.stop()

  def testLocalInitOp(self):
    logdir = self._test_dir("default_local_init_op")
    with ops.Graph().as_default():
      # A local variable.
      v = variables.VariableV1(
          [1.0, 2.0, 3.0],
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES])

      # An entity which is initialized through a TABLE_INITIALIZER.
      w = variables.VariableV1([4, 5, 6], trainable=False, collections=[])
      ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, w.initializer)

      # This shouldn't add a variable to the VARIABLES collection responsible
      # for variables that are saved/restored from checkpoints.
      self.assertEqual(len(variables.global_variables()), 0)

      # Suppress normal variable inits to make sure the local one is
      # initialized via local_init_op.
      sv = supervisor.Supervisor(logdir=logdir, init_op=None)
      sess = sv.prepare_or_wait_for_session("")
      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
      self.assertAllClose([4, 5, 6], sess.run(w))
      sv.stop()

  def testLocalInitOpForNonChief(self):
    logdir = self._test_dir("default_local_init_op_non_chief")
    with ops.Graph().as_default():
      with ops.device("/job:localhost"):
        # A local variable.
        v = variables.VariableV1(
            [1.0, 2.0, 3.0],
            trainable=False,
            collections=[ops.GraphKeys.LOCAL_VARIABLES])
        # This shouldn't add a variable to the VARIABLES collection responsible
        # for variables that are saved/restored from checkpoints.
        self.assertEqual(len(variables.global_variables()), 0)

      # Suppress normal variable inits to make sure the local one is
      # initialized via local_init_op.
      sv = supervisor.Supervisor(logdir=logdir, init_op=None, is_chief=False)
      sess = sv.prepare_or_wait_for_session("")
      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
      sv.stop()

  def testInitOpFails(self):
    server = server_lib.Server.create_local_server()
    logdir = self._test_dir("default_init_op_fails")
    with ops.Graph().as_default():
      v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
      variables.VariableV1([4.0, 5.0, 6.0], name="w")
      # w will not be initialized.
      sv = supervisor.Supervisor(logdir=logdir, init_op=v.initializer)
      with self.assertRaisesRegex(RuntimeError, "Variables not initialized: w"):
        sv.prepare_or_wait_for_session(server.target)

  def testInitOpFailsForTransientVariable(self):
    server = server_lib.Server.create_local_server()
    logdir = self._test_dir("default_init_op_fails_for_local_variable")
    with ops.Graph().as_default():
      v = variables.VariableV1(
          [1.0, 2.0, 3.0],
          name="v",
          collections=[ops.GraphKeys.LOCAL_VARIABLES])
      variables.VariableV1(
          [1.0, 2.0, 3.0],
          name="w",
          collections=[ops.GraphKeys.LOCAL_VARIABLES])
      # w will not be initialized.
      sv = supervisor.Supervisor(logdir=logdir, local_init_op=v.initializer)
      with self.assertRaisesRegex(RuntimeError, "Variables not initialized: w"):
        sv.prepare_or_wait_for_session(server.target)

  @test_util.run_v1_only("train.Supervisor is for v1 only")
  def testSetupFail(self):
    logdir = self._test_dir("setup_fail")
    # A non-chief supervisor over a graph whose variable has no device set
    # is expected to be rejected.
    with ops.Graph().as_default():
      variables.VariableV1([1.0, 2.0, 3.0], name="v")
      with self.assertRaisesRegex(ValueError, "must have their device set"):
        supervisor.Supervisor(logdir=logdir, is_chief=False)
    # With an explicit device the same construction must succeed.
    with ops.Graph().as_default(), ops.device("/job:ps"):
      variables.VariableV1([1.0, 2.0, 3.0], name="v")
      supervisor.Supervisor(logdir=logdir, is_chief=False)

  @test_util.run_v1_only("train.Supervisor is for v1 only")
  def testDefaultGlobalStep(self):
    logdir = self._test_dir("default_global_step")
    with ops.Graph().as_default():
      variables.VariableV1(287, name="global_step")
      sv = supervisor.Supervisor(logdir=logdir)
      sess = sv.prepare_or_wait_for_session("")
      self.assertEqual(287, sess.run(sv.global_step))
      sv.stop()

  @test_util.run_v1_only("train.Supervisor is for v1 only")
  def testRestoreFromMetaGraph(self):
    logdir = self._test_dir("restore_from_meta_graph")
    with ops.Graph().as_default():
      variables.VariableV1(1, name="v0")
      sv = supervisor.Supervisor(logdir=logdir)
      sess = sv.prepare_or_wait_for_session("")
      filename = sv.saver.save(sess, sv.save_path)
      sv.stop()
    # Create a new Graph and Supervisor and recover.
    with ops.Graph().as_default():
      new_saver = saver_lib.import_meta_graph(filename + ".meta")
      self.assertIsNotNone(new_saver)
      sv2 = supervisor.Supervisor(logdir=logdir, saver=new_saver)
      sess = sv2.prepare_or_wait_for_session("")
      self.assertEqual(1, sess.run("v0:0"))
      sv2.saver.save(sess, sv2.save_path)
      sv2.stop()

  # This test is based on the fact that the standard services start
  # right away and get to run once before sv.stop() returns.
  # We still sleep a bit to make the test robust.
  @test_util.run_v1_only("train.Supervisor is for v1 only")
  def testStandardServicesWithoutGlobalStep(self):
    logdir = self._test_dir("standard_services_without_global_step")
    # Create a checkpoint.
    with ops.Graph().as_default():
      v = variables.VariableV1([1.0], name="foo")
      summary.scalar("v", v[0])
      sv = supervisor.Supervisor(logdir=logdir)
      meta_graph_def = meta_graph.create_meta_graph_def(
          saver_def=sv.saver.saver_def)
      sess = sv.prepare_or_wait_for_session("")
      save_path = sv.save_path
      self._wait_for_glob(save_path, 3.0)
      self._wait_for_glob(
          os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
      # Wait to make sure everything is written to file before stopping.
      time.sleep(1)
      sv.stop()
    # There should be an event file with a version number, followed by the
    # graph and the metagraph.
    rr = _summary_iterator(logdir)
    self._assert_graph_events(rr, sess.graph, meta_graph_def)

    ev = next(rr)
    self.assertProtoEquals("value { tag: 'v' simple_value: 1.0 }", ev.summary)

    ev = next(rr)
    self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)

    self._assert_no_more_events(rr)
    # There should be a checkpoint file with the variable "foo"
    with ops.Graph().as_default(), self.cached_session() as sess:
      # The initial value differs from the saved one so that a successful
      # restore is observable.
      v = variables.VariableV1([10.10], name="foo")
      sav = saver_lib.Saver([v])
      sav.restore(sess, save_path)
      self.assertEqual(1.0, self.evaluate(v)[0])

  # Same as testStandardServicesWithoutGlobalStep but with a global step.
  # We should get a summary about the step time.
  @test_util.run_v1_only("train.Supervisor is for v1 only")
  def testStandardServicesWithGlobalStep(self):
    logdir = self._test_dir("standard_services_with_global_step")
    # Create a checkpoint.
    with ops.Graph().as_default():
      variables.VariableV1([123], name="global_step")
      sv = supervisor.Supervisor(logdir=logdir)
      meta_graph_def = meta_graph.create_meta_graph_def(
          saver_def=sv.saver.saver_def)
      sess = sv.prepare_or_wait_for_session("")
      # This is where the checkpoint will appear, with step number 123.
      save_path = "%s-123" % sv.save_path
      self._wait_for_glob(save_path, 3.0)
      self._wait_for_glob(
          os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
      # Wait to make sure everything is written to file before stopping.
      time.sleep(1)
      sv.stop()
    # There should be an event file with a version number, followed by the
    # graph and the metagraph.
    rr = _summary_iterator(logdir)
    self._assert_graph_events(rr, sess.graph, meta_graph_def)
    ev = next(rr)
    # It is actually nondeterministic whether SessionLog.START gets written
    # before the summary or the checkpoint, but this works when run 10000 times.
    self.assertEqual(123, ev.step)
    self.assertEqual(event_pb2.SessionLog.START, ev.session_log.status)
    first = next(rr)
    second = next(rr)
    # It is nondeterministic whether the value gets written before the
    # checkpoint since they are on separate threads, so we check for both
    # conditions.
    if first.HasField("summary"):
      self.assertProtoEquals("""value { tag: 'global_step/sec'
                                        simple_value: 0.0 }""", first.summary)
      self.assertEqual(123, second.step)
      self.assertEqual(event_pb2.SessionLog.CHECKPOINT,
                       second.session_log.status)
    else:
      self.assertEqual(123, first.step)
      self.assertEqual(event_pb2.SessionLog.CHECKPOINT,
                       first.session_log.status)
      self.assertProtoEquals("""value { tag: 'global_step/sec'
                                        simple_value: 0.0 }""", second.summary)
    ev = next(rr)
    self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
    self._assert_no_more_events(rr)
    # There should be a checkpoint file with the variable "global_step"
    with ops.Graph().as_default(), self.cached_session() as sess:
      # The initial value differs from the saved one so that a successful
      # restore is observable.
      v = variables.VariableV1([-12], name="global_step")
      sav = saver_lib.Saver([v])
      sav.restore(sess, save_path)
      self.assertEqual(123, self.evaluate(v)[0])

  def testNoQueueRunners(self):
    with ops.Graph().as_default(), self.cached_session() as sess:
      sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners"))
      self.assertEqual(0, len(sv.start_queue_runners(sess)))
      sv.stop()

  def testPrepareSessionAfterStopForChief(self):
    logdir = self._test_dir("prepare_after_stop_chief")
    with ops.Graph().as_default():
      sv = supervisor.Supervisor(logdir=logdir, is_chief=True)

      # Create a first session and then stop.
      sess = sv.prepare_or_wait_for_session("")
      sv.stop()
      sess.close()
      self.assertTrue(sv.should_stop())

      # Now create a second session and test that we don't stay stopped, until
      # we ask to stop again.
      sess2 = sv.prepare_or_wait_for_session("")
      self.assertFalse(sv.should_stop())
      sv.stop()
      sess2.close()
      self.assertTrue(sv.should_stop())

  def testPrepareSessionAfterStopForNonChief(self):
    logdir = self._test_dir("prepare_after_stop_nonchief")
    with ops.Graph().as_default():
      sv = supervisor.Supervisor(logdir=logdir, is_chief=False)

      # Create a first session and then stop.
      sess = sv.prepare_or_wait_for_session("")
      sv.stop()
      sess.close()
      self.assertTrue(sv.should_stop())

      # Now create a second session and test that we don't stay stopped, until
      # we ask to stop again.
      sess2 = sv.prepare_or_wait_for_session("")
      self.assertFalse(sv.should_stop())
      sv.stop()
      sess2.close()
      self.assertTrue(sv.should_stop())


if __name__ == "__main__":
  test.main()