1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Training helper that checkpoints models and computes summaries.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import contextlib 21import os 22import time 23 24from tensorflow.core.framework.summary_pb2 import Summary 25from tensorflow.core.util.event_pb2 import SessionLog 26from tensorflow.python.eager import context 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import meta_graph 29from tensorflow.python.framework import ops 30from tensorflow.python.ops import control_flow_ops 31from tensorflow.python.ops import lookup_ops 32from tensorflow.python.ops import variables 33from tensorflow.python.platform import tf_logging as logging 34from tensorflow.python.summary import summary as _summary 35from tensorflow.python.training import coordinator 36from tensorflow.python.training import saver as saver_mod 37from tensorflow.python.training import session_manager as session_manager_mod 38from tensorflow.python.training import training_util 39from tensorflow.python.util import deprecation 40from tensorflow.python.util.tf_export import tf_export 41 42 43@tf_export(v1=["train.Supervisor"]) 44class Supervisor(object): 45 """A training helper that checkpoints models and 
computes summaries. 46 47 This class is deprecated. Please use 48 `tf.train.MonitoredTrainingSession` instead. 49 50 The Supervisor is a small wrapper around a `Coordinator`, a `Saver`, 51 and a `SessionManager` that takes care of common needs of TensorFlow 52 training programs. 53 54 #### Use for a single program 55 56 ```python 57 with tf.Graph().as_default(): 58 ...add operations to the graph... 59 # Create a Supervisor that will checkpoint the model in '/tmp/mydir'. 60 sv = Supervisor(logdir='/tmp/mydir') 61 # Get a TensorFlow session managed by the supervisor. 62 with sv.managed_session(FLAGS.master) as sess: 63 # Use the session to train the graph. 64 while not sv.should_stop(): 65 sess.run(<my_train_op>) 66 ``` 67 68 Within the `with sv.managed_session()` block all variables in the graph have 69 been initialized. In addition, a few services have been started to 70 checkpoint the model and add summaries to the event log. 71 72 If the program crashes and is restarted, the managed session automatically 73 reinitialize variables from the most recent checkpoint. 74 75 The supervisor is notified of any exception raised by one of the services. 76 After an exception is raised, `should_stop()` returns `True`. In that case 77 the training loop should also stop. This is why the training loop has to 78 check for `sv.should_stop()`. 79 80 Exceptions that indicate that the training inputs have been exhausted, 81 `tf.errors.OutOfRangeError`, also cause `sv.should_stop()` to return `True` 82 but are not re-raised from the `with` block: they indicate a normal 83 termination. 84 85 #### Use for multiple replicas 86 87 To train with replicas you deploy the same program in a `Cluster`. 88 One of the tasks must be identified as the *chief*: the task that handles 89 initialization, checkpoints, summaries, and recovery. The other tasks 90 depend on the *chief* for these services. 
91 92 The only change you have to do to the single program code is to indicate 93 if the program is running as the *chief*. 94 95 ```python 96 # Choose a task as the chief. This could be based on server_def.task_index, 97 # or job_def.name, or job_def.tasks. It's entirely up to the end user. 98 # But there can be only one *chief*. 99 is_chief = (server_def.task_index == 0) 100 server = tf.train.Server(server_def) 101 102 with tf.Graph().as_default(): 103 ...add operations to the graph... 104 # Create a Supervisor that uses log directory on a shared file system. 105 # Indicate if you are the 'chief' 106 sv = Supervisor(logdir='/shared_directory/...', is_chief=is_chief) 107 # Get a Session in a TensorFlow server on the cluster. 108 with sv.managed_session(server.target) as sess: 109 # Use the session to train the graph. 110 while not sv.should_stop(): 111 sess.run(<my_train_op>) 112 ``` 113 114 In the *chief* task, the `Supervisor` works exactly as in the first example 115 above. In the other tasks `sv.managed_session()` waits for the Model to have 116 been initialized before returning a session to the training code. The 117 non-chief tasks depend on the chief task for initializing the model. 118 119 If one of the tasks crashes and restarts, `managed_session()` 120 checks if the Model is initialized. If yes, it just creates a session and 121 returns it to the training code that proceeds normally. If the model needs 122 to be initialized, the chief task takes care of reinitializing it; the other 123 tasks just wait for the model to have been initialized. 124 125 NOTE: This modified program still works fine as a single program. 126 The single program marks itself as the chief. 127 128 #### What `master` string to use 129 130 Whether you are running on your machine or in the cluster you can use the 131 following values for the --master flag: 132 133 * Specifying `''` requests an in-process session that does not use RPC. 
134 135 * Specifying `'local'` requests a session that uses the RPC-based 136 "Master interface" to run TensorFlow programs. See 137 `tf.train.Server.create_local_server` for 138 details. 139 140 * Specifying `'grpc://hostname:port'` requests a session that uses 141 the RPC interface to a specific host, and also allows the in-process 142 master to access remote tensorflow workers. Often, it is 143 appropriate to pass `server.target` (for some `tf.train.Server` 144 named `server). 145 146 #### Advanced use 147 148 ##### Launching additional services 149 150 `managed_session()` launches the Checkpoint and Summary services (threads). 151 If you need more services to run you can simply launch them in the block 152 controlled by `managed_session()`. 153 154 Example: Start a thread to print losses. We want this thread to run 155 every 60 seconds, so we launch it with `sv.loop()`. 156 157 ```python 158 ... 159 sv = Supervisor(logdir='/tmp/mydir') 160 with sv.managed_session(FLAGS.master) as sess: 161 sv.loop(60, print_loss, (sess, )) 162 while not sv.should_stop(): 163 sess.run(my_train_op) 164 ``` 165 166 ##### Launching fewer services 167 168 `managed_session()` launches the "summary" and "checkpoint" threads which use 169 either the optionally `summary_op` and `saver` passed to the constructor, or 170 default ones created automatically by the supervisor. If you want to run 171 your own summary and checkpointing logic, disable these services by passing 172 `None` to the `summary_op` and `saver` parameters. 173 174 Example: Create summaries manually every 100 steps in the chief. 175 176 ```python 177 # Create a Supervisor with no automatic summaries. 178 sv = Supervisor(logdir='/tmp/mydir', is_chief=is_chief, summary_op=None) 179 # As summary_op was None, managed_session() does not start the 180 # summary thread. 
181 with sv.managed_session(FLAGS.master) as sess: 182 for step in xrange(1000000): 183 if sv.should_stop(): 184 break 185 if is_chief and step % 100 == 0: 186 # Create the summary every 100 chief steps. 187 sv.summary_computed(sess, sess.run(my_summary_op)) 188 else: 189 # Train normally 190 sess.run(my_train_op) 191 ``` 192 193 ##### Custom model initialization 194 195 `managed_session()` only supports initializing the model by running an 196 `init_op` or restoring from the latest checkpoint. If you have special 197 initialization needs, see how to specify a `local_init_op` when creating the 198 supervisor. You can also use the `SessionManager` directly to create a 199 session and check if it could be initialized automatically. 200 """ 201 202 # Value to pass for the 'ready_op', 'init_op', 'summary_op', 'saver', 203 # and 'global_step' parameters of Supervisor.__init__() to indicate that 204 # the default behavior should be used. 205 USE_DEFAULT = 0 206 207 @deprecation.deprecated(None, 208 "Please switch to tf.train.MonitoredTrainingSession") 209 def __init__(self, 210 graph=None, 211 ready_op=USE_DEFAULT, 212 ready_for_local_init_op=USE_DEFAULT, 213 is_chief=True, 214 init_op=USE_DEFAULT, 215 init_feed_dict=None, 216 local_init_op=USE_DEFAULT, 217 logdir=None, 218 summary_op=USE_DEFAULT, 219 saver=USE_DEFAULT, 220 global_step=USE_DEFAULT, 221 save_summaries_secs=120, 222 save_model_secs=600, 223 recovery_wait_secs=30, 224 stop_grace_secs=120, 225 checkpoint_basename="model.ckpt", 226 session_manager=None, 227 summary_writer=USE_DEFAULT, 228 init_fn=None, 229 local_init_run_options=None): 230 """Create a `Supervisor`. 231 232 Args: 233 graph: A `Graph`. The graph that the model will use. Defaults to the 234 default `Graph`. The supervisor may add operations to the graph before 235 creating a session, but the graph should not be modified by the caller 236 after passing it to the supervisor. 237 ready_op: 1-D string `Tensor`. 
This tensor is evaluated by supervisors in 238 `prepare_or_wait_for_session()` to check if the model is ready to use. 239 The model is considered ready if it returns an empty array. Defaults to 240 the tensor returned from `tf.report_uninitialized_variables()` If 241 `None`, the model is not checked for readiness. 242 ready_for_local_init_op: 1-D string `Tensor`. This tensor is evaluated by 243 supervisors in `prepare_or_wait_for_session()` to check if the model is 244 ready to run the local_init_op. 245 The model is considered ready if it returns an empty array. Defaults to 246 `None`. If `None`, the model is not checked for readiness before running 247 local_init_op. 248 is_chief: If True, create a chief supervisor in charge of initializing 249 and restoring the model. If False, create a supervisor that relies 250 on a chief supervisor for inits and restore. 251 init_op: `Operation`. Used by chief supervisors to initialize the model 252 when it can not be recovered. Defaults to an `Operation` that 253 initializes all global variables. If `None`, no initialization is done 254 automatically unless you pass a value for `init_fn`, see below. 255 init_feed_dict: A dictionary that maps `Tensor` objects to feed values. 256 This feed dictionary will be used when `init_op` is evaluated. 257 local_init_op: `Operation`. Used by all supervisors to run initializations 258 that should run for every new supervisor instance. By default these 259 are table initializers and initializers for local variables. 260 If `None`, no further per supervisor-instance initialization is 261 done automatically. 262 logdir: A string. Optional path to a directory where to checkpoint the 263 model and log events for the visualizer. Used by chief supervisors. 264 The directory will be created if it does not exist. 265 summary_op: An `Operation` that returns a Summary for the event logs. 266 Used by chief supervisors if a `logdir` was specified. 
Defaults to the 267 operation returned from summary.merge_all(). If `None`, summaries are 268 not computed automatically. 269 saver: A Saver object. Used by chief supervisors if a `logdir` was 270 specified. Defaults to the saved returned by Saver(). 271 If `None`, the model is not saved automatically. 272 global_step: An integer Tensor of size 1 that counts steps. The value 273 from 'global_step' is used in summaries and checkpoint filenames. 274 Default to the op named 'global_step' in the graph if it exists, is of 275 rank 1, size 1, and of type tf.int32 or tf.int64. If `None` the global 276 step is not recorded in summaries and checkpoint files. Used by chief 277 supervisors if a `logdir` was specified. 278 save_summaries_secs: Number of seconds between the computation of 279 summaries for the event log. Defaults to 120 seconds. Pass 0 to 280 disable summaries. 281 save_model_secs: Number of seconds between the creation of model 282 checkpoints. Defaults to 600 seconds. Pass 0 to disable checkpoints. 283 recovery_wait_secs: Number of seconds between checks that the model 284 is ready. Used by supervisors when waiting for a chief supervisor 285 to initialize or restore the model. Defaults to 30 seconds. 286 stop_grace_secs: Grace period, in seconds, given to running threads to 287 stop when `stop()` is called. Defaults to 120 seconds. 288 checkpoint_basename: The basename for checkpoint saving. 289 session_manager: `SessionManager`, which manages Session creation and 290 recovery. If it is `None`, a default `SessionManager` will be created 291 with the set of arguments passed in for backwards compatibility. 292 summary_writer: `SummaryWriter` to use or `USE_DEFAULT`. Can be `None` 293 to indicate that no summaries should be written. 294 init_fn: Optional callable used to initialize the model. Called 295 after the optional `init_op` is called. The callable must accept one 296 argument, the session being initialized. 
297 local_init_run_options: RunOptions to be passed as the SessionManager 298 local_init_run_options parameter. 299 300 Returns: 301 A `Supervisor`. 302 303 Raises: 304 RuntimeError: If called with eager execution enabled. 305 306 @compatibility(eager) 307 `Supervisor`s are not supported when eager execution is enabled. 308 @end_compatibility 309 """ 310 if context.executing_eagerly(): 311 raise RuntimeError("Supervisors are compatible with eager execution.") 312 # Set default values of arguments. 313 if graph is None: 314 graph = ops.get_default_graph() 315 with graph.as_default(): 316 self._init_ready_op( 317 ready_op=ready_op, ready_for_local_init_op=ready_for_local_init_op) 318 self._init_init_op(init_op=init_op, init_feed_dict=init_feed_dict) 319 self._init_local_init_op(local_init_op=local_init_op) 320 self._init_saver(saver=saver) 321 self._init_summary_op(summary_op=summary_op) 322 self._init_global_step(global_step=global_step) 323 self._graph = graph 324 self._meta_graph_def = meta_graph.create_meta_graph_def( 325 graph_def=graph.as_graph_def(add_shapes=True), 326 saver_def=self._saver.saver_def if self._saver else None) 327 self._is_chief = is_chief 328 self._coord = coordinator.Coordinator() 329 self._recovery_wait_secs = recovery_wait_secs 330 self._stop_grace_secs = stop_grace_secs 331 self._init_fn = init_fn 332 self._local_init_run_options = local_init_run_options 333 334 # Set all attributes related to checkpointing and writing events to None. 335 # Afterwards, set them appropriately for chief supervisors, as these are 336 # the only supervisors that can write checkpoints and events. 
337 self._logdir = None 338 self._save_summaries_secs = None 339 self._save_model_secs = None 340 self._save_path = None 341 self._summary_writer = None 342 343 if self._is_chief: 344 self._logdir = logdir 345 self._save_summaries_secs = save_summaries_secs 346 self._save_model_secs = save_model_secs 347 if self._logdir: 348 self._save_path = os.path.join(self._logdir, checkpoint_basename) 349 if summary_writer is Supervisor.USE_DEFAULT: 350 if self._logdir: 351 self._summary_writer = _summary.FileWriter(self._logdir) 352 else: 353 self._summary_writer = summary_writer 354 self._graph_added_to_summary = False 355 356 self._init_session_manager(session_manager=session_manager) 357 self._verify_setup() 358 # The graph is not allowed to change anymore. 359 graph.finalize() 360 361 def _init_session_manager(self, session_manager=None): 362 if session_manager is None: 363 self._session_manager = session_manager_mod.SessionManager( 364 local_init_op=self._local_init_op, 365 ready_op=self._ready_op, 366 ready_for_local_init_op=self._ready_for_local_init_op, 367 graph=self._graph, 368 recovery_wait_secs=self._recovery_wait_secs, 369 local_init_run_options=self._local_init_run_options) 370 else: 371 self._session_manager = session_manager 372 373 def _get_first_op_from_collection(self, key): 374 """Returns the first `Operation` from a collection. 375 376 Args: 377 key: A string collection key. 378 379 Returns: 380 The first Op found in a collection, or `None` if the collection is empty. 381 """ 382 try: 383 op_list = ops.get_collection(key) 384 if len(op_list) > 1: 385 logging.info("Found %d %s operations. Returning the first one.", 386 len(op_list), key) 387 if op_list: 388 return op_list[0] 389 except LookupError: 390 pass 391 392 return None 393 394 def _init_ready_op(self, 395 ready_op=USE_DEFAULT, 396 ready_for_local_init_op=USE_DEFAULT): 397 """Initializes ready_op. 398 399 Args: 400 ready_op: `Tensor` to check if the model is initialized. 
401 If it's set to USE_DEFAULT, creates an op that checks all 402 the variables are initialized. 403 ready_for_local_init_op: `Tensor` to check if the model is ready to run 404 local_init_op. 405 If it's set to USE_DEFAULT, creates an op that checks all 406 the global variables are initialized. 407 """ 408 if ready_op is Supervisor.USE_DEFAULT: 409 ready_op = self._get_first_op_from_collection(ops.GraphKeys.READY_OP) 410 if ready_op is None: 411 ready_op = variables.report_uninitialized_variables() 412 ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op) 413 self._ready_op = ready_op 414 415 # ready_for_local_init_op defaults to None for backward compatibility 416 if ready_for_local_init_op is Supervisor.USE_DEFAULT: 417 ready_for_local_init_op = self._get_first_op_from_collection( 418 ops.GraphKeys.READY_FOR_LOCAL_INIT_OP) 419 self._ready_for_local_init_op = ready_for_local_init_op 420 421 def _init_init_op(self, init_op=USE_DEFAULT, init_feed_dict=None): 422 """Initializes init_op. 423 424 Args: 425 init_op: `Operation` to initialize the variables. If set to USE_DEFAULT, 426 create an op that initializes all variables and tables. 427 init_feed_dict: A dictionary that maps `Tensor` objects to feed values. 428 This feed dictionary will be used when `init_op` is evaluated. 429 """ 430 if init_op is Supervisor.USE_DEFAULT: 431 init_op = self._get_first_op_from_collection(ops.GraphKeys.INIT_OP) 432 if init_op is None: 433 init_op = variables.global_variables_initializer() 434 ops.add_to_collection(ops.GraphKeys.INIT_OP, init_op) 435 self._init_op = init_op 436 self._init_feed_dict = init_feed_dict 437 438 def _init_local_init_op(self, local_init_op=USE_DEFAULT): 439 """Initializes local_init_op. 440 441 Args: 442 local_init_op: `Operation` run for every new supervisor instance. If set 443 to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP 444 collection. If the collection is empty, create an op that initializes 445 all local variables and all tables. 
446 """ 447 if local_init_op is Supervisor.USE_DEFAULT: 448 local_init_op = self._get_first_op_from_collection( 449 ops.GraphKeys.LOCAL_INIT_OP) 450 if local_init_op is None: 451 op_list = [ 452 variables.local_variables_initializer(), 453 lookup_ops.tables_initializer() 454 ] 455 if op_list: 456 local_init_op = control_flow_ops.group(*op_list) 457 ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op) 458 self._local_init_op = local_init_op 459 460 def _init_saver(self, saver=USE_DEFAULT): 461 """Initializes saver. 462 463 Args: 464 saver: A `Saver` object. If set to USE_DEFAULT, create one that 465 saves all the variables. 466 """ 467 if saver is Supervisor.USE_DEFAULT: 468 saver = self._get_first_op_from_collection(ops.GraphKeys.SAVERS) 469 if saver is None and variables.global_variables(): 470 saver = saver_mod.Saver() 471 ops.add_to_collection(ops.GraphKeys.SAVERS, saver) 472 self._saver = saver 473 474 def _init_summary_op(self, summary_op=USE_DEFAULT): 475 """Initializes summary_op. 476 477 Args: 478 summary_op: An Operation that returns a Summary for the event logs. 479 If set to USE_DEFAULT, create an op that merges all the summaries. 480 """ 481 if summary_op is Supervisor.USE_DEFAULT: 482 summary_op = self._get_first_op_from_collection(ops.GraphKeys.SUMMARY_OP) 483 if summary_op is None: 484 summary_op = _summary.merge_all() 485 if summary_op is not None: 486 ops.add_to_collection(ops.GraphKeys.SUMMARY_OP, summary_op) 487 self._summary_op = summary_op 488 489 def _init_global_step(self, global_step=USE_DEFAULT): 490 """Initializes global_step. 491 492 Args: 493 global_step: An integer Tensor of size 1 that counts steps. If 494 set to USE_DEFAULT, creates global_step tensor. 
495 """ 496 if global_step is Supervisor.USE_DEFAULT: 497 global_step = self._get_first_op_from_collection( 498 ops.GraphKeys.GLOBAL_STEP) 499 if global_step is None: 500 global_step = self._default_global_step_tensor() 501 if global_step is not None: 502 ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, global_step) 503 self._global_step = global_step 504 505 @property 506 def is_chief(self): 507 """Return True if this is a chief supervisor. 508 509 Returns: 510 A bool. 511 """ 512 return self._is_chief 513 514 @property 515 def session_manager(self): 516 """Return the SessionManager used by the Supervisor. 517 518 Returns: 519 A SessionManager object. 520 """ 521 return self._session_manager 522 523 @property 524 def coord(self): 525 """Return the Coordinator used by the Supervisor. 526 527 The Coordinator can be useful if you want to run multiple threads 528 during your training. 529 530 Returns: 531 A Coordinator object. 532 """ 533 return self._coord 534 535 @property 536 def init_op(self): 537 """Return the Init Op used by the supervisor. 538 539 Returns: 540 An Op or `None`. 541 """ 542 return self._init_op 543 544 @property 545 def init_feed_dict(self): 546 """Return the feed dictionary used when evaluating the `init_op`. 547 548 Returns: 549 A feed dictionary or `None`. 550 """ 551 return self._init_feed_dict 552 553 @property 554 def ready_op(self): 555 """Return the Ready Op used by the supervisor. 556 557 Returns: 558 An Op or `None`. 559 """ 560 return self._ready_op 561 562 @property 563 def ready_for_local_init_op(self): 564 return self._ready_for_local_init_op 565 566 @property 567 def summary_writer(self): 568 """Return the SummaryWriter used by the chief supervisor. 569 570 Returns: 571 A SummaryWriter. 572 """ 573 return self._summary_writer 574 575 @property 576 def summary_op(self): 577 """Return the Summary Tensor used by the chief supervisor. 578 579 Returns: 580 A string Tensor for the summary or `None`. 
581 """ 582 return self._summary_op 583 584 @property 585 def save_summaries_secs(self): 586 """Return the delay between summary computations. 587 588 Returns: 589 A timestamp. 590 """ 591 return self._save_summaries_secs 592 593 @property 594 def global_step(self): 595 """Return the global_step Tensor used by the supervisor. 596 597 Returns: 598 An integer Tensor for the global_step. 599 """ 600 return self._global_step 601 602 @property 603 def saver(self): 604 """Return the Saver used by the supervisor. 605 606 Returns: 607 A Saver object. 608 """ 609 return self._saver 610 611 @property 612 def save_model_secs(self): 613 """Return the delay between checkpoints. 614 615 Returns: 616 A timestamp. 617 """ 618 return self._save_model_secs 619 620 @property 621 def save_path(self): 622 """Return the save path used by the supervisor. 623 624 Returns: 625 A string. 626 """ 627 return self._save_path 628 629 def _write_graph(self): 630 """Writes graph_def to `logdir` and adds it to summary if applicable.""" 631 assert self._is_chief 632 if self._logdir: 633 training_util.write_graph(self._graph.as_graph_def(add_shapes=True), 634 self._logdir, "graph.pbtxt") 635 if self._summary_writer and not self._graph_added_to_summary: 636 self._summary_writer.add_graph(self._graph) 637 self._summary_writer.add_meta_graph(self._meta_graph_def) 638 self._graph_added_to_summary = True 639 640 def start_standard_services(self, sess): 641 """Start the standard services for 'sess'. 642 643 This starts services in the background. The services started depend 644 on the parameters to the constructor and may include: 645 646 - A Summary thread computing summaries every save_summaries_secs. 647 - A Checkpoint thread saving the model every save_model_secs. 648 - A StepCounter thread measure step time. 649 650 Args: 651 sess: A Session. 652 653 Returns: 654 A list of threads that are running the standard services. 
You can use 655 the Supervisor's Coordinator to join these threads with: 656 sv.coord.Join(<list of threads>) 657 658 Raises: 659 RuntimeError: If called with a non-chief Supervisor. 660 ValueError: If not `logdir` was passed to the constructor as the 661 services need a log directory. 662 """ 663 if not self._is_chief: 664 raise RuntimeError("Only chief supervisor can start standard services. " 665 "Because only chief supervisors can write events.") 666 667 if not self._logdir: 668 logging.warning("Standard services need a 'logdir' " 669 "passed to the SessionManager") 670 return 671 672 if self._global_step is not None and self._summary_writer: 673 # Only add the session log if we keep track of global step. 674 # TensorBoard cannot use START message for purging expired events 675 # if there is no step value. 676 current_step = training_util.global_step(sess, self._global_step) 677 self._summary_writer.add_session_log( 678 SessionLog(status=SessionLog.START), 679 current_step) 680 681 threads = [] 682 if self._save_summaries_secs and self._summary_writer: 683 if self._summary_op is not None: 684 threads.append(SVSummaryThread(self, sess)) 685 if self._global_step is not None: 686 threads.append(SVStepCounterThread(self, sess)) 687 if self.saver and self._save_model_secs: 688 threads.append(SVTimerCheckpointThread(self, sess)) 689 for t in threads: 690 t.start() 691 return threads 692 693 def prepare_or_wait_for_session(self, master="", config=None, 694 wait_for_checkpoint=False, 695 max_wait_secs=7200, 696 start_standard_services=True): 697 """Make sure the model is ready to be used. 698 699 Create a session on 'master', recovering or initializing the model as 700 needed, or wait for a session to be ready. If running as the chief 701 and `start_standard_service` is set to True, also call the session 702 manager to start the standard services. 703 704 Args: 705 master: name of the TensorFlow master to use. 
See the `tf.Session` 706 constructor for how this is interpreted. 707 config: Optional ConfigProto proto used to configure the session, 708 which is passed as-is to create the session. 709 wait_for_checkpoint: Whether we should wait for the availability of a 710 checkpoint before creating Session. Defaults to False. 711 max_wait_secs: Maximum time to wait for the session to become available. 712 start_standard_services: Whether to start the standard services and the 713 queue runners. 714 715 Returns: 716 A Session object that can be used to drive the model. 717 """ 718 # For users who recreate the session with prepare_or_wait_for_session(), we 719 # need to clear the coordinator's stop_event so that threads managed by the 720 # coordinator can run. 721 self._coord.clear_stop() 722 if self._summary_writer: 723 self._summary_writer.reopen() 724 725 if self._is_chief: 726 sess = self._session_manager.prepare_session( 727 master, init_op=self.init_op, saver=self.saver, 728 checkpoint_dir=self._logdir, wait_for_checkpoint=wait_for_checkpoint, 729 max_wait_secs=max_wait_secs, config=config, 730 init_feed_dict=self._init_feed_dict, init_fn=self._init_fn) 731 self._write_graph() 732 if start_standard_services: 733 logging.info("Starting standard services.") 734 self.start_standard_services(sess) 735 else: 736 sess = self._session_manager.wait_for_session(master, 737 config=config, 738 max_wait_secs=max_wait_secs) 739 if start_standard_services: 740 logging.info("Starting queue runners.") 741 self.start_queue_runners(sess) 742 return sess 743 744 def start_queue_runners(self, sess, queue_runners=None): 745 """Start threads for `QueueRunners`. 746 747 Note that the queue runners collected in the graph key `QUEUE_RUNNERS` 748 are already started automatically when you create a session with the 749 supervisor, so unless you have non-collected queue runners to start 750 you do not need to call this explicitly. 751 752 Args: 753 sess: A `Session`. 
      queue_runners: A list of `QueueRunners`. If not specified, we'll use the
        list of queue runners gathered in the graph under the key
        `GraphKeys.QUEUE_RUNNERS`.

    Returns:
      The list of threads started for the `QueueRunners`.

    Raises:
      RuntimeError: If called with eager execution enabled.

    @compatibility(eager)
    Queues are not compatible with eager execution. To ingest data when eager
    execution is enabled, use the `tf.data` API.
    @end_compatibility
    """
    if context.executing_eagerly():
      raise RuntimeError("Queues are not compatible with eager execution.")
    if queue_runners is None:
      queue_runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
    threads = []
    for qr in queue_runners:
      # Threads are daemonic so they never block process exit, and they are
      # registered with the supervisor's coordinator so stop() can join them.
      threads.extend(qr.create_threads(sess, coord=self._coord, daemon=True,
                                       start=True))
    return threads

  def loop(self, timer_interval_secs, target, args=None, kwargs=None):
    """Start a LooperThread that calls a function periodically.

    If `timer_interval_secs` is None the thread calls `target(*args, **kwargs)`
    repeatedly.  Otherwise it calls it every `timer_interval_secs`
    seconds.  The thread terminates when a stop is requested.

    The started thread is added to the list of threads managed by the
    supervisor so it does not need to be passed to the `stop()` method.

    Args:
      timer_interval_secs: Number.  Time boundaries at which to call `target`.
      target: A callable object.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Returns:
      The started thread.
    """
    # The LooperThread registers itself with the supervisor's coordinator,
    # which is how stop() later finds and joins it.
    looper = coordinator.LooperThread(self._coord, timer_interval_secs,
                                      target=target, args=args, kwargs=kwargs)
    looper.start()
    return looper

  def stop(self,
           threads=None,
           close_summary_writer=True,
           ignore_live_threads=False):
    """Stop the services and the coordinator.

    This does not close the session.

    Args:
      threads: Optional list of threads to join with the coordinator.  If
        `None`, defaults to the threads running the standard services, the
        threads started for `QueueRunners`, and the threads started by the
        `loop()` method.  To wait on additional threads, pass the list in this
        parameter.
      close_summary_writer: Whether to close the `summary_writer`.  Defaults to
        `True` if the summary writer was created by the supervisor, `False`
        otherwise.
      ignore_live_threads: If `True` ignores threads that remain running after
        a grace period when joining threads via the coordinator, instead of
        raising a RuntimeError.
    """
    self._coord.request_stop()
    try:
      # coord.join() re-raises the first reported exception; the "finally"
      # block ensures that we clean up whether or not an exception was
      # reported.
      self._coord.join(
          threads,
          stop_grace_period_secs=self._stop_grace_secs,
          ignore_live_threads=ignore_live_threads)
    finally:
      # Close the writer last, in case one of the running threads was using it.
      if close_summary_writer and self._summary_writer:
        # Stop messages are not logged with event.step,
        # since the session may have already terminated.
        self._summary_writer.add_session_log(SessionLog(status=SessionLog.STOP))
        self._summary_writer.close()
        self._graph_added_to_summary = False

  def request_stop(self, ex=None):
    """Request that the coordinator stop the threads.

    See `Coordinator.request_stop()`.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.
    """
    self._coord.request_stop(ex=ex)

  def should_stop(self):
    """Check if the coordinator was told to stop.

    See `Coordinator.should_stop()`.

    Returns:
      True if the coordinator was told to stop, False otherwise.
    """
    return self._coord.should_stop()

  def stop_on_exception(self):
    """Context handler to stop the supervisor when an exception is raised.

    See `Coordinator.stop_on_exception()`.

    Returns:
      A context handler.
    """
    return self._coord.stop_on_exception()

  def wait_for_stop(self):
    """Block waiting for the coordinator to stop."""
    self._coord.wait_for_stop()

  def summary_computed(self, sess, summary, global_step=None):
    """Indicate that a summary was computed.

    Args:
      sess: A `Session` object.
      summary: A Summary proto, or a string holding a serialized summary proto.
      global_step: Int. global step this summary is associated with. If `None`,
        it will try to fetch the current step.

    Raises:
      TypeError: if 'summary' is not a Summary proto or a string.
      RuntimeError: if the Supervisor was created without a `logdir`.
    """
    if not self._summary_writer:
      raise RuntimeError("Writing a summary requires a summary writer.")
    # When no step is given, evaluate the supervisor's global_step tensor so
    # the event is tagged with the current training step.
    if global_step is None and self.global_step is not None:
      global_step = training_util.global_step(sess, self.global_step)
    self._summary_writer.add_summary(summary, global_step)

  def _default_global_step_tensor(self):
    """Returns the global_step from the default graph.

    Returns:
      The global step `Tensor` or `None`.
    """
    try:
      gs = ops.get_default_graph().get_tensor_by_name("global_step:0")
      # Only integer tensors are usable as a step counter.
      if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]:
        return gs
      else:
        logging.warning("Found 'global_step' is not an int type: %s", gs.dtype)
        return None
    except KeyError:
      # No tensor named "global_step" exists in the default graph.
      return None

  def _verify_setup(self):
    """Check that all is good.

    Raises:
      ValueError: If something is not good.
    """
    # Not running as chief means that replicas are used.
    # In that case all Variables must have their device set.
    if not self._is_chief:
      for op in self._graph.get_operations():
        if op.type in ["Variable", "VariableV2"] and not op.device:
          raise ValueError("When using replicas, all Variables must have "
                           "their device set: %s" % op)

  # pylint: disable=g-doc-return-or-yield,broad-except
  @contextlib.contextmanager
  def managed_session(self, master="", config=None,
                      start_standard_services=True,
                      close_summary_writer=True):
    """Returns a context manager for a managed session.

    This context manager creates and automatically recovers a session.  It
    optionally starts the standard services that handle checkpoints and
    summaries.  It monitors exceptions raised from the `with` block or from
    the services and stops the supervisor as needed.

    The context manager is typically used as follows:

    ```python
    def train():
      sv = tf.train.Supervisor(...)
      with sv.managed_session(<master>) as sess:
        for step in xrange(..):
          if sv.should_stop():
            break
          sess.run(<my training op>)
          ...do other things needed at each training step...
    ```

    An exception raised from the `with` block or one of the service threads is
    raised again when the block exits.  This is done after stopping all
    threads and closing the session.  For example, an `AbortedError`
    exception, raised in case of preemption of one of the workers in a
    distributed model, is raised again when the block exits.

    If you want to retry the training loop in case of preemption you can do it
    as follows:

    ```python
    def main(...):
      while True:
        try:
          train()
        except tf.errors.Aborted:
          pass
    ```

    As a special case, exceptions used for control flow, such as
    `OutOfRangeError` which reports that input queues are exhausted, are not
    raised again from the `with` block: they indicate a clean termination of
    the training loop and are considered normal termination.

    Args:
      master: name of the TensorFlow master to use.  See the `tf.Session`
        constructor for how this is interpreted.
      config: Optional `ConfigProto` proto used to configure the session.
        Passed as-is to create the session.
      start_standard_services: Whether to start the standard services,
        such as checkpoint, summary and step counter.
      close_summary_writer: Whether to close the summary writer when
        closing the session.  Defaults to True.

    Returns:
      A context manager that yields a `Session` restored from the latest
      checkpoint or initialized from scratch if no checkpoint exists.  The
      session is closed when the `with` block exits.
    """
    try:
      sess = self.prepare_or_wait_for_session(
          master=master, config=config,
          start_standard_services=start_standard_services)
      yield sess
    except Exception as e:
      # Record the exception on the coordinator; stop() below will re-raise
      # it after the threads have been joined.
      self.request_stop(e)
    finally:
      try:
        # Request all the threads to stop and wait for them to do so.  Any
        # exception raised by the threads is raised again from stop().
        # Passing stop_grace_period_secs is for blocked enqueue/dequeue
        # threads which are not checking for `should_stop()`.  They
        # will be stopped when we close the session further down.
        self.stop(close_summary_writer=close_summary_writer)
      finally:
        # Close the session to finish up all pending calls.  We do not care
        # about exceptions raised when closing.
        # This takes care of blocked enqueue/dequeue calls.
        try:
          sess.close()
        except Exception:
          # Silently ignore exceptions raised by close().
          pass
  # pylint: enable=g-doc-return-or-yield,broad-except


class SVSummaryThread(coordinator.LooperThread):
  """A thread to save summaries on a timer."""

  def __init__(self, sv, sess):
    """Create a SVSummaryThread.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
    """
    super(SVSummaryThread, self).__init__(sv.coord, sv.save_summaries_secs)
    self._sv = sv
    self._sess = sess

  def run_loop(self):
    """Evaluates the summary op and writes the result to the event file.

    Called every `save_summaries_secs` by the LooperThread machinery.
    """
    # Fetch the global step in the same session.run() call when one is
    # available so the recorded summary is tagged with a consistent step.
    if self._sv.global_step is not None:
      summary_strs, global_step = self._sess.run([self._sv.summary_op,
                                                  self._sv.global_step])
    else:
      summary_strs = self._sess.run(self._sv.summary_op)
      global_step = None
    if self._sv.summary_writer:
      logging.info("Recording summary at step %s.", global_step)
      self._sv.summary_writer.add_summary(summary_strs, global_step)


class SVStepCounterThread(coordinator.LooperThread):
  """Threads to count steps and measure their duration."""

  def __init__(self, sv, sess, step_counter=None):
    """Create a `SVStepCounterThread`.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
      step_counter: A `Tensor` holding the step counter. By defaults, it uses
        sv.global_step.
    """
    super(SVStepCounterThread, self).__init__(sv.coord, sv.save_summaries_secs)
    self._sv = sv
    self._sess = sess
    self._last_time = 0.0
    self._last_step = 0
    step_counter = sv.global_step if step_counter is None else step_counter
    self._step_counter = step_counter
    # Summary tag, e.g. "global_step/sec".
    self._summary_tag = "%s/sec" % self._step_counter.op.name

  def start_loop(self):
    # Initialize the baseline so the first run_loop() measures a real delta
    # rather than the total since thread creation.
    self._last_time = time.time()
    self._last_step = training_util.global_step(
        self._sess, self._step_counter)

  def run_loop(self):
    """Records a steps/sec summary; called every `save_summaries_secs`."""
    # Count the steps.
    current_step = training_util.global_step(self._sess, self._step_counter)
    added_steps = current_step - self._last_step
    self._last_step = current_step
    # Measure the elapsed time.
    current_time = time.time()
    elapsed_time = current_time - self._last_time
    self._last_time = current_time
    # Reports the number of steps done per second
    if elapsed_time > 0.:
      steps_per_sec = added_steps / elapsed_time
    else:
      # The clock did not advance between runs; avoid a ZeroDivisionError.
      steps_per_sec = float("inf")
    summary = Summary(value=[Summary.Value(tag=self._summary_tag,
                                           simple_value=steps_per_sec)])
    if self._sv.summary_writer:
      self._sv.summary_writer.add_summary(summary, current_step)
    logging.log_first_n(logging.INFO, "%s: %g", 10,
                        self._summary_tag, steps_per_sec)


class SVTimerCheckpointThread(coordinator.LooperThread):
  """A thread to checkpoint on a timer."""

  def __init__(self, sv, sess):
    """Create a `SVTimerCheckpointThread`.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
    """
    super(SVTimerCheckpointThread, self).__init__(sv.coord, sv.save_model_secs)
    self._sv = sv
    self._sess = sess

  def run_loop(self):
    """Saves a checkpoint; called every `save_model_secs`."""
    logging.info("Saving checkpoint to path %s", self._sv.save_path)
    self._sv.saver.save(self._sess, self._sv.save_path,
                        global_step=self._sv.global_step)
    if self._sv.summary_writer and self._sv.global_step is not None:
      current_step = training_util.global_step(self._sess, self._sv.global_step)
      # Record a CHECKPOINT session log so tools reading the event file can
      # associate the checkpoint with this step.
      self._sv.summary_writer.add_session_log(
          SessionLog(status=SessionLog.CHECKPOINT,
                     checkpoint_path=self._sv.save_path),
          current_step)


# TODO(sherrym): All non-PEP8 compliant names will be deprecated shortly.
# Legacy CamelCase aliases kept for backward compatibility; the snake_case
# methods defined on the class are the canonical API.  Plain attribute
# assignment is equivalent to the setattr() calls it replaces.
Supervisor.PrepareSession = Supervisor.prepare_or_wait_for_session
Supervisor.StartQueueRunners = Supervisor.start_queue_runners
Supervisor.StartStandardServices = Supervisor.start_standard_services
Supervisor.Stop = Supervisor.stop
Supervisor.RequestStop = Supervisor.request_stop
Supervisor.Loop = Supervisor.loop
Supervisor.ShouldStop = Supervisor.should_stop
Supervisor.StopOnException = Supervisor.stop_on_exception
Supervisor.WaitForStop = Supervisor.wait_for_stop
Supervisor.SummaryComputed = Supervisor.summary_computed