1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Training helper that checkpoints models and computes summaries.""" 16import contextlib 17import os 18import time 19 20from tensorflow.core.framework.summary_pb2 import Summary 21from tensorflow.core.util.event_pb2 import SessionLog 22from tensorflow.python.eager import context 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import meta_graph 25from tensorflow.python.framework import ops 26from tensorflow.python.ops import control_flow_ops 27from tensorflow.python.ops import lookup_ops 28from tensorflow.python.ops import variables 29from tensorflow.python.platform import tf_logging as logging 30from tensorflow.python.summary import summary as _summary 31from tensorflow.python.training import coordinator 32from tensorflow.python.training import saver as saver_mod 33from tensorflow.python.training import session_manager as session_manager_mod 34from tensorflow.python.training import training_util 35from tensorflow.python.util import deprecation 36from tensorflow.python.util.tf_export import tf_export 37 38 39@tf_export(v1=["train.Supervisor"]) 40class Supervisor: 41 """A training helper that checkpoints models and computes summaries. 42 43 This class is deprecated. Please use 44 `tf.compat.v1.train.MonitoredTrainingSession` instead. 45 46 The Supervisor is a small wrapper around a `Coordinator`, a `Saver`, 47 and a `SessionManager` that takes care of common needs of TensorFlow 48 training programs. 49 50 #### Use for a single program 51 52 ```python 53 with tf.Graph().as_default(): 54 ...add operations to the graph... 55 # Create a Supervisor that will checkpoint the model in '/tmp/mydir'. 56 sv = Supervisor(logdir='/tmp/mydir') 57 # Get a TensorFlow session managed by the supervisor. 58 with sv.managed_session(FLAGS.master) as sess: 59 # Use the session to train the graph. 60 while not sv.should_stop(): 61 sess.run(<my_train_op>) 62 ``` 63 64 Within the `with sv.managed_session()` block all variables in the graph have 65 been initialized. In addition, a few services have been started to 66 checkpoint the model and add summaries to the event log. 67 68 If the program crashes and is restarted, the managed session automatically 69 reinitialize variables from the most recent checkpoint. 70 71 The supervisor is notified of any exception raised by one of the services. 72 After an exception is raised, `should_stop()` returns `True`. In that case 73 the training loop should also stop. This is why the training loop has to 74 check for `sv.should_stop()`. 75 76 Exceptions that indicate that the training inputs have been exhausted, 77 `tf.errors.OutOfRangeError`, also cause `sv.should_stop()` to return `True` 78 but are not re-raised from the `with` block: they indicate a normal 79 termination. 80 81 #### Use for multiple replicas 82 83 To train with replicas you deploy the same program in a `Cluster`. 84 One of the tasks must be identified as the *chief*: the task that handles 85 initialization, checkpoints, summaries, and recovery. The other tasks 86 depend on the *chief* for these services. 87 88 The only change you have to do to the single program code is to indicate 89 if the program is running as the *chief*. 90 91 ```python 92 # Choose a task as the chief. This could be based on server_def.task_index, 93 # or job_def.name, or job_def.tasks. It's entirely up to the end user. 94 # But there can be only one *chief*. 95 is_chief = (server_def.task_index == 0) 96 server = tf.distribute.Server(server_def) 97 98 with tf.Graph().as_default(): 99 ...add operations to the graph... 100 # Create a Supervisor that uses log directory on a shared file system. 101 # Indicate if you are the 'chief' 102 sv = Supervisor(logdir='/shared_directory/...', is_chief=is_chief) 103 # Get a Session in a TensorFlow server on the cluster. 104 with sv.managed_session(server.target) as sess: 105 # Use the session to train the graph. 106 while not sv.should_stop(): 107 sess.run(<my_train_op>) 108 ``` 109 110 In the *chief* task, the `Supervisor` works exactly as in the first example 111 above. In the other tasks `sv.managed_session()` waits for the Model to have 112 been initialized before returning a session to the training code. The 113 non-chief tasks depend on the chief task for initializing the model. 114 115 If one of the tasks crashes and restarts, `managed_session()` 116 checks if the Model is initialized. If yes, it just creates a session and 117 returns it to the training code that proceeds normally. If the model needs 118 to be initialized, the chief task takes care of reinitializing it; the other 119 tasks just wait for the model to have been initialized. 120 121 NOTE: This modified program still works fine as a single program. 122 The single program marks itself as the chief. 123 124 #### What `master` string to use 125 126 Whether you are running on your machine or in the cluster you can use the 127 following values for the --master flag: 128 129 * Specifying `''` requests an in-process session that does not use RPC. 130 131 * Specifying `'local'` requests a session that uses the RPC-based 132 "Master interface" to run TensorFlow programs. See 133 `tf.train.Server.create_local_server` for 134 details. 135 136 * Specifying `'grpc://hostname:port'` requests a session that uses 137 the RPC interface to a specific host, and also allows the in-process 138 master to access remote tensorflow workers. Often, it is 139 appropriate to pass `server.target` (for some `tf.distribute.Server` 140 named `server). 141 142 #### Advanced use 143 144 ##### Launching additional services 145 146 `managed_session()` launches the Checkpoint and Summary services (threads). 147 If you need more services to run you can simply launch them in the block 148 controlled by `managed_session()`. 149 150 Example: Start a thread to print losses. We want this thread to run 151 every 60 seconds, so we launch it with `sv.loop()`. 152 153 ```python 154 ... 155 sv = Supervisor(logdir='/tmp/mydir') 156 with sv.managed_session(FLAGS.master) as sess: 157 sv.loop(60, print_loss, (sess, )) 158 while not sv.should_stop(): 159 sess.run(my_train_op) 160 ``` 161 162 ##### Launching fewer services 163 164 `managed_session()` launches the "summary" and "checkpoint" threads which use 165 either the optionally `summary_op` and `saver` passed to the constructor, or 166 default ones created automatically by the supervisor. If you want to run 167 your own summary and checkpointing logic, disable these services by passing 168 `None` to the `summary_op` and `saver` parameters. 169 170 Example: Create summaries manually every 100 steps in the chief. 171 172 ```python 173 # Create a Supervisor with no automatic summaries. 174 sv = Supervisor(logdir='/tmp/mydir', is_chief=is_chief, summary_op=None) 175 # As summary_op was None, managed_session() does not start the 176 # summary thread. 177 with sv.managed_session(FLAGS.master) as sess: 178 for step in range(1000000): 179 if sv.should_stop(): 180 break 181 if is_chief and step % 100 == 0: 182 # Create the summary every 100 chief steps. 183 sv.summary_computed(sess, sess.run(my_summary_op)) 184 else: 185 # Train normally 186 sess.run(my_train_op) 187 ``` 188 189 ##### Custom model initialization 190 191 `managed_session()` only supports initializing the model by running an 192 `init_op` or restoring from the latest checkpoint. If you have special 193 initialization needs, see how to specify a `local_init_op` when creating the 194 supervisor. You can also use the `SessionManager` directly to create a 195 session and check if it could be initialized automatically. 196 """ 197 198 # Value to pass for the 'ready_op', 'init_op', 'summary_op', 'saver', 199 # and 'global_step' parameters of Supervisor.__init__() to indicate that 200 # the default behavior should be used. 201 USE_DEFAULT = 0 202 203 @deprecation.deprecated(None, 204 "Please switch to tf.train.MonitoredTrainingSession") 205 def __init__(self, 206 graph=None, 207 ready_op=USE_DEFAULT, 208 ready_for_local_init_op=USE_DEFAULT, 209 is_chief=True, 210 init_op=USE_DEFAULT, 211 init_feed_dict=None, 212 local_init_op=USE_DEFAULT, 213 logdir=None, 214 summary_op=USE_DEFAULT, 215 saver=USE_DEFAULT, 216 global_step=USE_DEFAULT, 217 save_summaries_secs=120, 218 save_model_secs=600, 219 recovery_wait_secs=30, 220 stop_grace_secs=120, 221 checkpoint_basename="model.ckpt", 222 session_manager=None, 223 summary_writer=USE_DEFAULT, 224 init_fn=None, 225 local_init_run_options=None): 226 """Create a `Supervisor`. 227 228 Args: 229 graph: A `Graph`. The graph that the model will use. Defaults to the 230 default `Graph`. The supervisor may add operations to the graph before 231 creating a session, but the graph should not be modified by the caller 232 after passing it to the supervisor. 233 ready_op: 1-D string `Tensor`. This tensor is evaluated by supervisors in 234 `prepare_or_wait_for_session()` to check if the model is ready to use. 235 The model is considered ready if it returns an empty array. Defaults to 236 the tensor returned from `tf.compat.v1.report_uninitialized_variables()` 237 If `None`, the model is not checked for readiness. 238 ready_for_local_init_op: 1-D string `Tensor`. This tensor is evaluated by 239 supervisors in `prepare_or_wait_for_session()` to check if the model is 240 ready to run the local_init_op. The model is considered ready if it 241 returns an empty array. Defaults to `None`. If `None`, the model is not 242 checked for readiness before running local_init_op. 243 is_chief: If True, create a chief supervisor in charge of initializing and 244 restoring the model. If False, create a supervisor that relies on a 245 chief supervisor for inits and restore. 246 init_op: `Operation`. Used by chief supervisors to initialize the model 247 when it can not be recovered. Defaults to an `Operation` that 248 initializes all global variables. If `None`, no initialization is done 249 automatically unless you pass a value for `init_fn`, see below. 250 init_feed_dict: A dictionary that maps `Tensor` objects to feed values. 251 This feed dictionary will be used when `init_op` is evaluated. 252 local_init_op: `Operation`. Used by all supervisors to run initializations 253 that should run for every new supervisor instance. By default these are 254 table initializers and initializers for local variables. If `None`, no 255 further per supervisor-instance initialization is done automatically. 256 logdir: A string. Optional path to a directory where to checkpoint the 257 model and log events for the visualizer. Used by chief supervisors. The 258 directory will be created if it does not exist. 259 summary_op: An `Operation` that returns a Summary for the event logs. Used 260 by chief supervisors if a `logdir` was specified. Defaults to the 261 operation returned from summary.merge_all(). If `None`, summaries are 262 not computed automatically. 263 saver: A Saver object. Used by chief supervisors if a `logdir` was 264 specified. Defaults to the saved returned by Saver(). If `None`, the 265 model is not saved automatically. 266 global_step: An integer Tensor of size 1 that counts steps. The value 267 from 'global_step' is used in summaries and checkpoint filenames. 268 Default to the op named 'global_step' in the graph if it exists, is of 269 rank 1, size 1, and of type tf.int32 or tf.int64. If `None` the global 270 step is not recorded in summaries and checkpoint files. Used by chief 271 supervisors if a `logdir` was specified. 272 save_summaries_secs: Number of seconds between the computation of 273 summaries for the event log. Defaults to 120 seconds. Pass 0 to 274 disable summaries. 275 save_model_secs: Number of seconds between the creation of model 276 checkpoints. Defaults to 600 seconds. Pass 0 to disable checkpoints. 277 recovery_wait_secs: Number of seconds between checks that the model is 278 ready. Used by supervisors when waiting for a chief supervisor to 279 initialize or restore the model. Defaults to 30 seconds. 280 stop_grace_secs: Grace period, in seconds, given to running threads to 281 stop when `stop()` is called. Defaults to 120 seconds. 282 checkpoint_basename: The basename for checkpoint saving. 283 session_manager: `SessionManager`, which manages Session creation and 284 recovery. If it is `None`, a default `SessionManager` will be created 285 with the set of arguments passed in for backwards compatibility. 286 summary_writer: `SummaryWriter` to use or `USE_DEFAULT`. Can be `None` to 287 indicate that no summaries should be written. 288 init_fn: Optional callable used to initialize the model. Called after the 289 optional `init_op` is called. The callable must accept one argument, 290 the session being initialized. 291 local_init_run_options: RunOptions to be passed as the SessionManager 292 local_init_run_options parameter. 293 294 Returns: 295 A `Supervisor`. 296 297 Raises: 298 RuntimeError: If called with eager execution enabled. 299 300 @compatibility(eager) 301 `Supervisor`s are not supported when eager execution is enabled. 302 @end_compatibility 303 """ 304 if context.executing_eagerly(): 305 raise RuntimeError("Supervisors are incompatible with eager execution.") 306 # Set default values of arguments. 307 if graph is None: 308 graph = ops.get_default_graph() 309 with graph.as_default(): 310 self._init_ready_op( 311 ready_op=ready_op, ready_for_local_init_op=ready_for_local_init_op) 312 self._init_init_op(init_op=init_op, init_feed_dict=init_feed_dict) 313 self._init_local_init_op(local_init_op=local_init_op) 314 self._init_saver(saver=saver) 315 self._init_summary_op(summary_op=summary_op) 316 self._init_global_step(global_step=global_step) 317 self._graph = graph 318 self._meta_graph_def = meta_graph.create_meta_graph_def( 319 graph_def=graph.as_graph_def(add_shapes=True), 320 saver_def=self._saver.saver_def if self._saver else None) 321 self._is_chief = is_chief 322 self._coord = coordinator.Coordinator() 323 self._recovery_wait_secs = recovery_wait_secs 324 self._stop_grace_secs = stop_grace_secs 325 self._init_fn = init_fn 326 self._local_init_run_options = local_init_run_options 327 328 # Set all attributes related to checkpointing and writing events to None. 329 # Afterwards, set them appropriately for chief supervisors, as these are 330 # the only supervisors that can write checkpoints and events. 331 self._logdir = None 332 self._save_summaries_secs = None 333 self._save_model_secs = None 334 self._save_path = None 335 self._summary_writer = None 336 337 if self._is_chief: 338 self._logdir = logdir 339 self._save_summaries_secs = save_summaries_secs 340 self._save_model_secs = save_model_secs 341 if self._logdir: 342 self._save_path = os.path.join(self._logdir, checkpoint_basename) 343 if summary_writer is Supervisor.USE_DEFAULT: 344 if self._logdir: 345 self._summary_writer = _summary.FileWriter(self._logdir) 346 else: 347 self._summary_writer = summary_writer 348 self._graph_added_to_summary = False 349 350 self._init_session_manager(session_manager=session_manager) 351 self._verify_setup() 352 # The graph is not allowed to change anymore. 353 graph.finalize() 354 355 def _init_session_manager(self, session_manager=None): 356 if session_manager is None: 357 self._session_manager = session_manager_mod.SessionManager( 358 local_init_op=self._local_init_op, 359 ready_op=self._ready_op, 360 ready_for_local_init_op=self._ready_for_local_init_op, 361 graph=self._graph, 362 recovery_wait_secs=self._recovery_wait_secs, 363 local_init_run_options=self._local_init_run_options) 364 else: 365 self._session_manager = session_manager 366 367 def _get_first_op_from_collection(self, key): 368 """Returns the first `Operation` from a collection. 369 370 Args: 371 key: A string collection key. 372 373 Returns: 374 The first Op found in a collection, or `None` if the collection is empty. 375 """ 376 try: 377 op_list = ops.get_collection(key) 378 if len(op_list) > 1: 379 logging.info("Found %d %s operations. Returning the first one.", 380 len(op_list), key) 381 if op_list: 382 return op_list[0] 383 except LookupError: 384 pass 385 386 return None 387 388 def _init_ready_op(self, 389 ready_op=USE_DEFAULT, 390 ready_for_local_init_op=USE_DEFAULT): 391 """Initializes ready_op. 392 393 Args: 394 ready_op: `Tensor` to check if the model is initialized. If it's set to 395 USE_DEFAULT, creates an op that checks all the variables are 396 initialized. 397 ready_for_local_init_op: `Tensor` to check if the model is ready to run 398 local_init_op. If it's set to USE_DEFAULT, creates an op that checks all 399 the global variables are initialized. 400 """ 401 if ready_op is Supervisor.USE_DEFAULT: 402 ready_op = self._get_first_op_from_collection(ops.GraphKeys.READY_OP) 403 if ready_op is None: 404 ready_op = variables.report_uninitialized_variables() 405 ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op) 406 self._ready_op = ready_op 407 408 # ready_for_local_init_op defaults to None for backward compatibility 409 if ready_for_local_init_op is Supervisor.USE_DEFAULT: 410 ready_for_local_init_op = self._get_first_op_from_collection( 411 ops.GraphKeys.READY_FOR_LOCAL_INIT_OP) 412 self._ready_for_local_init_op = ready_for_local_init_op 413 414 def _init_init_op(self, init_op=USE_DEFAULT, init_feed_dict=None): 415 """Initializes init_op. 416 417 Args: 418 init_op: `Operation` to initialize the variables. If set to USE_DEFAULT, 419 create an op that initializes all variables and tables. 420 init_feed_dict: A dictionary that maps `Tensor` objects to feed values. 421 This feed dictionary will be used when `init_op` is evaluated. 422 """ 423 if init_op is Supervisor.USE_DEFAULT: 424 init_op = self._get_first_op_from_collection(ops.GraphKeys.INIT_OP) 425 if init_op is None: 426 init_op = variables.global_variables_initializer() 427 ops.add_to_collection(ops.GraphKeys.INIT_OP, init_op) 428 self._init_op = init_op 429 self._init_feed_dict = init_feed_dict 430 431 def _init_local_init_op(self, local_init_op=USE_DEFAULT): 432 """Initializes local_init_op. 433 434 Args: 435 local_init_op: `Operation` run for every new supervisor instance. If set 436 to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP 437 collection. If the collection is empty, create an op that initializes 438 all local variables and all tables. 439 """ 440 if local_init_op is Supervisor.USE_DEFAULT: 441 local_init_op = self._get_first_op_from_collection( 442 ops.GraphKeys.LOCAL_INIT_OP) 443 if local_init_op is None: 444 op_list = [ 445 variables.local_variables_initializer(), 446 lookup_ops.tables_initializer() 447 ] 448 if op_list: 449 local_init_op = control_flow_ops.group(*op_list) 450 ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op) 451 self._local_init_op = local_init_op 452 453 def _init_saver(self, saver=USE_DEFAULT): 454 """Initializes saver. 455 456 Args: 457 saver: A `Saver` object. If set to USE_DEFAULT, create one that saves all 458 the variables. 459 """ 460 if saver is Supervisor.USE_DEFAULT: 461 saver = self._get_first_op_from_collection(ops.GraphKeys.SAVERS) 462 if saver is None and variables.global_variables(): 463 saver = saver_mod.Saver() 464 ops.add_to_collection(ops.GraphKeys.SAVERS, saver) 465 self._saver = saver 466 467 def _init_summary_op(self, summary_op=USE_DEFAULT): 468 """Initializes summary_op. 469 470 Args: 471 summary_op: An Operation that returns a Summary for the event logs. If set 472 to USE_DEFAULT, create an op that merges all the summaries. 473 """ 474 if summary_op is Supervisor.USE_DEFAULT: 475 summary_op = self._get_first_op_from_collection(ops.GraphKeys.SUMMARY_OP) 476 if summary_op is None: 477 summary_op = _summary.merge_all() 478 if summary_op is not None: 479 ops.add_to_collection(ops.GraphKeys.SUMMARY_OP, summary_op) 480 self._summary_op = summary_op 481 482 def _init_global_step(self, global_step=USE_DEFAULT): 483 """Initializes global_step. 484 485 Args: 486 global_step: An integer Tensor of size 1 that counts steps. If set to 487 USE_DEFAULT, creates global_step tensor. 488 """ 489 if global_step is Supervisor.USE_DEFAULT: 490 global_step = self._get_first_op_from_collection( 491 ops.GraphKeys.GLOBAL_STEP) 492 if global_step is None: 493 global_step = self._default_global_step_tensor() 494 if global_step is not None: 495 ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, global_step) 496 self._global_step = global_step 497 498 @property 499 def is_chief(self): 500 """Return True if this is a chief supervisor. 501 502 Returns: 503 A bool. 504 """ 505 return self._is_chief 506 507 @property 508 def session_manager(self): 509 """Return the SessionManager used by the Supervisor. 510 511 Returns: 512 A SessionManager object. 513 """ 514 return self._session_manager 515 516 @property 517 def coord(self): 518 """Return the Coordinator used by the Supervisor. 519 520 The Coordinator can be useful if you want to run multiple threads 521 during your training. 522 523 Returns: 524 A Coordinator object. 525 """ 526 return self._coord 527 528 @property 529 def init_op(self): 530 """Return the Init Op used by the supervisor. 531 532 Returns: 533 An Op or `None`. 534 """ 535 return self._init_op 536 537 @property 538 def init_feed_dict(self): 539 """Return the feed dictionary used when evaluating the `init_op`. 540 541 Returns: 542 A feed dictionary or `None`. 543 """ 544 return self._init_feed_dict 545 546 @property 547 def ready_op(self): 548 """Return the Ready Op used by the supervisor. 549 550 Returns: 551 An Op or `None`. 552 """ 553 return self._ready_op 554 555 @property 556 def ready_for_local_init_op(self): 557 return self._ready_for_local_init_op 558 559 @property 560 def summary_writer(self): 561 """Return the SummaryWriter used by the chief supervisor. 562 563 Returns: 564 A SummaryWriter. 565 """ 566 return self._summary_writer 567 568 @property 569 def summary_op(self): 570 """Return the Summary Tensor used by the chief supervisor. 571 572 Returns: 573 A string Tensor for the summary or `None`. 574 """ 575 return self._summary_op 576 577 @property 578 def save_summaries_secs(self): 579 """Return the delay between summary computations. 580 581 Returns: 582 A timestamp. 583 """ 584 return self._save_summaries_secs 585 586 @property 587 def global_step(self): 588 """Return the global_step Tensor used by the supervisor. 589 590 Returns: 591 An integer Tensor for the global_step. 592 """ 593 return self._global_step 594 595 @property 596 def saver(self): 597 """Return the Saver used by the supervisor. 598 599 Returns: 600 A Saver object. 601 """ 602 return self._saver 603 604 @property 605 def save_model_secs(self): 606 """Return the delay between checkpoints. 607 608 Returns: 609 A timestamp. 610 """ 611 return self._save_model_secs 612 613 @property 614 def save_path(self): 615 """Return the save path used by the supervisor. 616 617 Returns: 618 A string. 619 """ 620 return self._save_path 621 622 def _write_graph(self): 623 """Writes graph_def to `logdir` and adds it to summary if applicable.""" 624 assert self._is_chief 625 if self._logdir: 626 training_util.write_graph( 627 self._graph.as_graph_def(add_shapes=True), self._logdir, 628 "graph.pbtxt") 629 if self._summary_writer and not self._graph_added_to_summary: 630 self._summary_writer.add_graph(self._graph) 631 self._summary_writer.add_meta_graph(self._meta_graph_def) 632 self._graph_added_to_summary = True 633 634 def start_standard_services(self, sess): 635 """Start the standard services for 'sess'. 636 637 This starts services in the background. The services started depend 638 on the parameters to the constructor and may include: 639 640 - A Summary thread computing summaries every save_summaries_secs. 641 - A Checkpoint thread saving the model every save_model_secs. 642 - A StepCounter thread measure step time. 643 644 Args: 645 sess: A Session. 646 647 Returns: 648 A list of threads that are running the standard services. You can use 649 the Supervisor's Coordinator to join these threads with: 650 sv.coord.Join(<list of threads>) 651 652 Raises: 653 RuntimeError: If called with a non-chief Supervisor. 654 ValueError: If not `logdir` was passed to the constructor as the 655 services need a log directory. 656 """ 657 if not self._is_chief: 658 raise RuntimeError("Only chief supervisor can start standard services. " 659 "Because only chief supervisors can write events.") 660 661 if not self._logdir: 662 logging.warning("Standard services need a 'logdir' " 663 "passed to the SessionManager") 664 return 665 666 if self._global_step is not None and self._summary_writer: 667 # Only add the session log if we keep track of global step. 668 # TensorBoard cannot use START message for purging expired events 669 # if there is no step value. 670 current_step = training_util.global_step(sess, self._global_step) 671 self._summary_writer.add_session_log( 672 SessionLog(status=SessionLog.START), current_step) 673 674 threads = [] 675 if self._save_summaries_secs and self._summary_writer: 676 if self._summary_op is not None: 677 threads.append(SVSummaryThread(self, sess)) 678 if self._global_step is not None: 679 threads.append(SVStepCounterThread(self, sess)) 680 if self.saver and self._save_model_secs: 681 threads.append(SVTimerCheckpointThread(self, sess)) 682 for t in threads: 683 t.start() 684 return threads 685 686 def prepare_or_wait_for_session(self, 687 master="", 688 config=None, 689 wait_for_checkpoint=False, 690 max_wait_secs=7200, 691 start_standard_services=True): 692 """Make sure the model is ready to be used. 693 694 Create a session on 'master', recovering or initializing the model as 695 needed, or wait for a session to be ready. If running as the chief 696 and `start_standard_service` is set to True, also call the session 697 manager to start the standard services. 698 699 Args: 700 master: name of the TensorFlow master to use. See the 701 `tf.compat.v1.Session` constructor for how this is interpreted. 702 config: Optional ConfigProto proto used to configure the session, which is 703 passed as-is to create the session. 704 wait_for_checkpoint: Whether we should wait for the availability of a 705 checkpoint before creating Session. Defaults to False. 706 max_wait_secs: Maximum time to wait for the session to become available. 707 start_standard_services: Whether to start the standard services and the 708 queue runners. 709 710 Returns: 711 A Session object that can be used to drive the model. 712 """ 713 # For users who recreate the session with prepare_or_wait_for_session(), we 714 # need to clear the coordinator's stop_event so that threads managed by the 715 # coordinator can run. 716 self._coord.clear_stop() 717 if self._summary_writer: 718 self._summary_writer.reopen() 719 720 if self._is_chief: 721 sess = self._session_manager.prepare_session( 722 master, 723 init_op=self.init_op, 724 saver=self.saver, 725 checkpoint_dir=self._logdir, 726 wait_for_checkpoint=wait_for_checkpoint, 727 max_wait_secs=max_wait_secs, 728 config=config, 729 init_feed_dict=self._init_feed_dict, 730 init_fn=self._init_fn) 731 self._write_graph() 732 if start_standard_services: 733 logging.info("Starting standard services.") 734 self.start_standard_services(sess) 735 else: 736 sess = self._session_manager.wait_for_session( 737 master, config=config, max_wait_secs=max_wait_secs) 738 if start_standard_services: 739 logging.info("Starting queue runners.") 740 self.start_queue_runners(sess) 741 return sess 742 743 def start_queue_runners(self, sess, queue_runners=None): 744 """Start threads for `QueueRunners`. 745 746 Note that the queue runners collected in the graph key `QUEUE_RUNNERS` 747 are already started automatically when you create a session with the 748 supervisor, so unless you have non-collected queue runners to start 749 you do not need to call this explicitly. 750 751 Args: 752 sess: A `Session`. 753 queue_runners: A list of `QueueRunners`. If not specified, we'll use the 754 list of queue runners gathered in the graph under the key 755 `GraphKeys.QUEUE_RUNNERS`. 756 757 Returns: 758 The list of threads started for the `QueueRunners`. 759 760 Raises: 761 RuntimeError: If called with eager execution enabled. 762 763 @compatibility(eager) 764 Queues are not compatible with eager execution. To ingest data when eager 765 execution is enabled, use the `tf.data` API. 766 @end_compatibility 767 """ 768 if context.executing_eagerly(): 769 raise RuntimeError("Queues are not compatible with eager execution.") 770 if queue_runners is None: 771 queue_runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS) 772 threads = [] 773 for qr in queue_runners: 774 threads.extend( 775 qr.create_threads(sess, coord=self._coord, daemon=True, start=True)) 776 return threads 777 778 def loop(self, timer_interval_secs, target, args=None, kwargs=None): 779 """Start a LooperThread that calls a function periodically. 780 781 If `timer_interval_secs` is None the thread calls `target(*args, **kwargs)` 782 repeatedly. Otherwise it calls it every `timer_interval_secs` 783 seconds. The thread terminates when a stop is requested. 784 785 The started thread is added to the list of threads managed by the supervisor 786 so it does not need to be passed to the `stop()` method. 787 788 Args: 789 timer_interval_secs: Number. Time boundaries at which to call `target`. 790 target: A callable object. 791 args: Optional arguments to pass to `target` when calling it. 792 kwargs: Optional keyword arguments to pass to `target` when calling it. 793 794 Returns: 795 The started thread. 796 """ 797 looper = coordinator.LooperThread( 798 self._coord, 799 timer_interval_secs, 800 target=target, 801 args=args, 802 kwargs=kwargs) 803 looper.start() 804 return looper 805 806 def stop(self, 807 threads=None, 808 close_summary_writer=True, 809 ignore_live_threads=False): 810 """Stop the services and the coordinator. 811 812 This does not close the session. 813 814 Args: 815 threads: Optional list of threads to join with the coordinator. If 816 `None`, defaults to the threads running the standard services, the 817 threads started for `QueueRunners`, and the threads started by the 818 `loop()` method. To wait on additional threads, pass the list in this 819 parameter. 820 close_summary_writer: Whether to close the `summary_writer`. Defaults to 821 `True` if the summary writer was created by the supervisor, `False` 822 otherwise. 823 ignore_live_threads: If `True` ignores threads that remain running after a 824 grace period when joining threads via the coordinator, instead of 825 raising a RuntimeError. 826 """ 827 self._coord.request_stop() 828 try: 829 # coord.join() re-raises the first reported exception; the "finally" 830 # block ensures that we clean up whether or not an exception was 831 # reported. 832 self._coord.join( 833 threads, 834 stop_grace_period_secs=self._stop_grace_secs, 835 ignore_live_threads=ignore_live_threads) 836 finally: 837 # Close the writer last, in case one of the running threads was using it. 838 if close_summary_writer and self._summary_writer: 839 # Stop messages are not logged with event.step, 840 # since the session may have already terminated. 841 self._summary_writer.add_session_log(SessionLog(status=SessionLog.STOP)) 842 self._summary_writer.close() 843 self._graph_added_to_summary = False 844 845 def request_stop(self, ex=None): 846 """Request that the coordinator stop the threads. 847 848 See `Coordinator.request_stop()`. 849 850 Args: 851 ex: Optional `Exception`, or Python `exc_info` tuple as returned by 852 `sys.exc_info()`. If this is the first call to `request_stop()` the 853 corresponding exception is recorded and re-raised from `join()`. 854 """ 855 self._coord.request_stop(ex=ex) 856 857 def should_stop(self): 858 """Check if the coordinator was told to stop. 859 860 See `Coordinator.should_stop()`. 861 862 Returns: 863 True if the coordinator was told to stop, False otherwise. 864 """ 865 return self._coord.should_stop() 866 867 def stop_on_exception(self): 868 """Context handler to stop the supervisor when an exception is raised. 869 870 See `Coordinator.stop_on_exception()`. 871 872 Returns: 873 A context handler. 874 """ 875 return self._coord.stop_on_exception() 876 877 def wait_for_stop(self): 878 """Block waiting for the coordinator to stop.""" 879 self._coord.wait_for_stop() 880 881 def summary_computed(self, sess, summary, global_step=None): 882 """Indicate that a summary was computed. 883 884 Args: 885 sess: A `Session` object. 886 summary: A Summary proto, or a string holding a serialized summary proto. 887 global_step: Int. global step this summary is associated with. If `None`, 888 it will try to fetch the current step. 889 890 Raises: 891 TypeError: if 'summary' is not a Summary proto or a string. 892 RuntimeError: if the Supervisor was created without a `logdir`. 893 """ 894 if not self._summary_writer: 895 raise RuntimeError("Writing a summary requires a summary writer.") 896 if global_step is None and self.global_step is not None: 897 global_step = training_util.global_step(sess, self.global_step) 898 self._summary_writer.add_summary(summary, global_step) 899 900 def _default_global_step_tensor(self): 901 """Returns the global_step from the default graph. 902 903 Returns: 904 The global step `Tensor` or `None`. 905 """ 906 try: 907 gs = ops.get_default_graph().get_tensor_by_name("global_step:0") 908 if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]: 909 return gs 910 else: 911 logging.warning("Found 'global_step' is not an int type: %s", gs.dtype) 912 return None 913 except KeyError: 914 return None 915 916 def _verify_setup(self): 917 """Check that all is good. 918 919 Raises: 920 ValueError: If something is not good. 921 """ 922 # Not running as chief means that replicas are used. 923 # In that case all Variables must have their device set. 924 if not self._is_chief: 925 for op in self._graph.get_operations(): 926 if op.type in ["Variable", "VariableV2"] and not op.device: 927 raise ValueError("When using replicas, all Variables must have " 928 "their device set: %s" % op) 929 930 # pylint: disable=g-doc-return-or-yield,broad-except 931 @contextlib.contextmanager 932 def managed_session(self, 933 master="", 934 config=None, 935 start_standard_services=True, 936 close_summary_writer=True): 937 """Returns a context manager for a managed session. 938 939 This context manager creates and automatically recovers a session. It 940 optionally starts the standard services that handle checkpoints and 941 summaries. It monitors exceptions raised from the `with` block or from the 942 services and stops the supervisor as needed. 943 944 The context manager is typically used as follows: 945 946 ```python 947 def train(): 948 sv = tf.compat.v1.train.Supervisor(...) 949 with sv.managed_session(<master>) as sess: 950 for step in range(..): 951 if sv.should_stop(): 952 break 953 sess.run(<my training op>) 954 ...do other things needed at each training step... 955 ``` 956 957 An exception raised from the `with` block or one of the service threads is 958 raised again when the block exits. This is done after stopping all threads 959 and closing the session. For example, an `AbortedError` exception, raised 960 in case of preemption of one of the workers in a distributed model, is 961 raised again when the block exits. 962 963 If you want to retry the training loop in case of preemption you can do it 964 as follows: 965 966 ```python 967 def main(...): 968 while True 969 try: 970 train() 971 except tf.errors.Aborted: 972 pass 973 ``` 974 975 As a special case, exceptions used for control flow, such as 976 `OutOfRangeError` which reports that input queues are exhausted, are not 977 raised again from the `with` block: they indicate a clean termination of 978 the training loop and are considered normal termination. 979 980 Args: 981 master: name of the TensorFlow master to use. See the 982 `tf.compat.v1.Session` constructor for how this is interpreted. 983 config: Optional `ConfigProto` proto used to configure the session. Passed 984 as-is to create the session. 985 start_standard_services: Whether to start the standard services, such as 986 checkpoint, summary and step counter. 987 close_summary_writer: Whether to close the summary writer when closing the 988 session. Defaults to True. 989 990 Returns: 991 A context manager that yields a `Session` restored from the latest 992 checkpoint or initialized from scratch if not checkpoint exists. The 993 session is closed when the `with` block exits. 994 """ 995 try: 996 sess = self.prepare_or_wait_for_session( 997 master=master, 998 config=config, 999 start_standard_services=start_standard_services) 1000 yield sess 1001 except Exception as e: 1002 self.request_stop(e) 1003 finally: 1004 try: 1005 # Request all the threads to stop and wait for them to do so. Any 1006 # exception raised by the threads is raised again from stop(). 1007 # Passing stop_grace_period_secs is for blocked enqueue/dequeue 1008 # threads which are not checking for `should_stop()`. They 1009 # will be stopped when we close the session further down. 1010 self.stop(close_summary_writer=close_summary_writer) 1011 finally: 1012 # Close the session to finish up all pending calls. We do not care 1013 # about exceptions raised when closing. This takes care of 1014 # blocked enqueue/dequeue calls. 1015 try: 1016 sess.close() 1017 except Exception: 1018 # Silently ignore exceptions raised by close(). 1019 pass 1020 1021 # pylint: enable=g-doc-return-or-yield,broad-except 1022 1023 1024class SVSummaryThread(coordinator.LooperThread): 1025 """A thread to save summaries on a timer.""" 1026 1027 def __init__(self, sv, sess): 1028 """Create a SVSummaryThread. 1029 1030 Args: 1031 sv: A `Supervisor`. 1032 sess: A `Session`. 1033 """ 1034 super(SVSummaryThread, self).__init__(sv.coord, sv.save_summaries_secs) 1035 self._sv = sv 1036 self._sess = sess 1037 1038 def run_loop(self): 1039 if self._sv.global_step is not None: 1040 summary_strs, global_step = self._sess.run( 1041 [self._sv.summary_op, self._sv.global_step]) 1042 else: 1043 summary_strs = self._sess.run(self._sv.summary_op) 1044 global_step = None 1045 if self._sv.summary_writer: 1046 logging.info("Recording summary at step %s.", global_step) 1047 self._sv.summary_writer.add_summary(summary_strs, global_step) 1048 1049 1050class SVStepCounterThread(coordinator.LooperThread): 1051 """Threads to count steps and measure their duration.""" 1052 1053 def __init__(self, sv, sess, step_counter=None): 1054 """Create a `SVStepCounterThread`. 1055 1056 Args: 1057 sv: A `Supervisor`. 1058 sess: A `Session`. 1059 step_counter: A `Tensor` holding the step counter. By defaults, it uses 1060 sv.global_step. 1061 """ 1062 super(SVStepCounterThread, self).__init__(sv.coord, sv.save_summaries_secs) 1063 self._sv = sv 1064 self._sess = sess 1065 self._last_time = 0.0 1066 self._last_step = 0 1067 step_counter = sv.global_step if step_counter is None else step_counter 1068 self._step_counter = step_counter 1069 self._summary_tag = "%s/sec" % self._step_counter.op.name 1070 1071 def start_loop(self): 1072 self._last_time = time.time() 1073 self._last_step = training_util.global_step(self._sess, self._step_counter) 1074 1075 def run_loop(self): 1076 # Count the steps. 1077 current_step = training_util.global_step(self._sess, self._step_counter) 1078 added_steps = current_step - self._last_step 1079 self._last_step = current_step 1080 # Measure the elapsed time. 1081 current_time = time.time() 1082 elapsed_time = current_time - self._last_time 1083 self._last_time = current_time 1084 # Reports the number of steps done per second 1085 if elapsed_time > 0.: 1086 steps_per_sec = added_steps / elapsed_time 1087 else: 1088 steps_per_sec = float("inf") 1089 summary = Summary(value=[ 1090 Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec) 1091 ]) 1092 if self._sv.summary_writer: 1093 self._sv.summary_writer.add_summary(summary, current_step) 1094 logging.log_first_n(logging.INFO, "%s: %g", 10, self._summary_tag, 1095 steps_per_sec) 1096 1097 1098class SVTimerCheckpointThread(coordinator.LooperThread): 1099 """A thread to checkpoint on a timer.""" 1100 1101 def __init__(self, sv, sess): 1102 """Create a `SVTimerCheckpointThread`. 1103 1104 Args: 1105 sv: A `Supervisor`. 1106 sess: A `Session`. 1107 """ 1108 super(SVTimerCheckpointThread, self).__init__(sv.coord, sv.save_model_secs) 1109 self._sv = sv 1110 self._sess = sess 1111 1112 def run_loop(self): 1113 logging.info("Saving checkpoint to path %s", self._sv.save_path) 1114 self._sv.saver.save( 1115 self._sess, self._sv.save_path, global_step=self._sv.global_step) 1116 if self._sv.summary_writer and self._sv.global_step is not None: 1117 current_step = training_util.global_step(self._sess, self._sv.global_step) 1118 self._sv.summary_writer.add_session_log( 1119 SessionLog( 1120 status=SessionLog.CHECKPOINT, checkpoint_path=self._sv.save_path), 1121 current_step) 1122 1123 1124# TODO(sherrym): All non-PEP8 compliant names will be deprecated shortly. 1125setattr(Supervisor, "PrepareSession", Supervisor.prepare_or_wait_for_session) 1126setattr(Supervisor, "StartQueueRunners", Supervisor.start_queue_runners) 1127setattr(Supervisor, "StartStandardServices", Supervisor.start_standard_services) 1128setattr(Supervisor, "Stop", Supervisor.stop) 1129setattr(Supervisor, "RequestStop", Supervisor.request_stop) 1130setattr(Supervisor, "Loop", Supervisor.loop) 1131setattr(Supervisor, "ShouldStop", Supervisor.should_stop) 1132setattr(Supervisor, "StopOnException", Supervisor.stop_on_exception) 1133setattr(Supervisor, "WaitForStop", Supervisor.wait_for_stop) 1134setattr(Supervisor, "SummaryComputed", Supervisor.summary_computed) 1135