• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training helper that checkpoints models and computes summaries."""
16import contextlib
17import os
18import time
19
20from tensorflow.core.framework.summary_pb2 import Summary
21from tensorflow.core.util.event_pb2 import SessionLog
22from tensorflow.python.eager import context
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import meta_graph
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.ops import lookup_ops
28from tensorflow.python.ops import variables
29from tensorflow.python.platform import tf_logging as logging
30from tensorflow.python.summary import summary as _summary
31from tensorflow.python.training import coordinator
32from tensorflow.python.training import saver as saver_mod
33from tensorflow.python.training import session_manager as session_manager_mod
34from tensorflow.python.training import training_util
35from tensorflow.python.util import deprecation
36from tensorflow.python.util.tf_export import tf_export
37
38
39@tf_export(v1=["train.Supervisor"])
40class Supervisor:
41  """A training helper that checkpoints models and computes summaries.
42
43  This class is deprecated. Please use
44  `tf.compat.v1.train.MonitoredTrainingSession` instead.
45
46  The Supervisor is a small wrapper around a `Coordinator`, a `Saver`,
47  and a `SessionManager` that takes care of common needs of TensorFlow
48  training programs.
49
50  #### Use for a single program
51
52  ```python
53  with tf.Graph().as_default():
54    ...add operations to the graph...
55    # Create a Supervisor that will checkpoint the model in '/tmp/mydir'.
56    sv = Supervisor(logdir='/tmp/mydir')
57    # Get a TensorFlow session managed by the supervisor.
58    with sv.managed_session(FLAGS.master) as sess:
59      # Use the session to train the graph.
60      while not sv.should_stop():
61        sess.run(<my_train_op>)
62  ```
63
64  Within the `with sv.managed_session()` block all variables in the graph have
65  been initialized.  In addition, a few services have been started to
66  checkpoint the model and add summaries to the event log.
67
68  If the program crashes and is restarted, the managed session automatically
  reinitializes variables from the most recent checkpoint.
70
71  The supervisor is notified of any exception raised by one of the services.
72  After an exception is raised, `should_stop()` returns `True`.  In that case
73  the training loop should also stop.  This is why the training loop has to
74  check for `sv.should_stop()`.
75
76  Exceptions that indicate that the training inputs have been exhausted,
77  `tf.errors.OutOfRangeError`, also cause `sv.should_stop()` to return `True`
78  but are not re-raised from the `with` block: they indicate a normal
79  termination.
80
81  #### Use for multiple replicas
82
83  To train with replicas you deploy the same program in a `Cluster`.
84  One of the tasks must be identified as the *chief*: the task that handles
85  initialization, checkpoints, summaries, and recovery.  The other tasks
86  depend on the *chief* for these services.
87
88  The only change you have to do to the single program code is to indicate
89  if the program is running as the *chief*.
90
91  ```python
92  # Choose a task as the chief. This could be based on server_def.task_index,
93  # or job_def.name, or job_def.tasks. It's entirely up to the end user.
94  # But there can be only one *chief*.
95  is_chief = (server_def.task_index == 0)
96  server = tf.distribute.Server(server_def)
97
98  with tf.Graph().as_default():
99    ...add operations to the graph...
100    # Create a Supervisor that uses log directory on a shared file system.
101    # Indicate if you are the 'chief'
102    sv = Supervisor(logdir='/shared_directory/...', is_chief=is_chief)
103    # Get a Session in a TensorFlow server on the cluster.
104    with sv.managed_session(server.target) as sess:
105      # Use the session to train the graph.
106      while not sv.should_stop():
107        sess.run(<my_train_op>)
108  ```
109
110  In the *chief* task, the `Supervisor` works exactly as in the first example
111  above.  In the other tasks `sv.managed_session()` waits for the Model to have
112  been initialized before returning a session to the training code.  The
113  non-chief tasks depend on the chief task for initializing the model.
114
115  If one of the tasks crashes and restarts, `managed_session()`
116  checks if the Model is initialized.  If yes, it just creates a session and
117  returns it to the training code that proceeds normally.  If the model needs
118  to be initialized, the chief task takes care of reinitializing it; the other
119  tasks just wait for the model to have been initialized.
120
121  NOTE: This modified program still works fine as a single program.
122  The single program marks itself as the chief.
123
124  #### What `master` string to use
125
126  Whether you are running on your machine or in the cluster you can use the
127  following values for the --master flag:
128
129  * Specifying `''` requests an in-process session that does not use RPC.
130
131  * Specifying `'local'` requests a session that uses the RPC-based
132    "Master interface" to run TensorFlow programs. See
133    `tf.train.Server.create_local_server` for
134    details.
135
136  * Specifying `'grpc://hostname:port'` requests a session that uses
137    the RPC interface to a specific host, and also allows the in-process
138    master to access remote tensorflow workers. Often, it is
139    appropriate to pass `server.target` (for some `tf.distribute.Server`
    named `server`).
141
142  #### Advanced use
143
144  ##### Launching additional services
145
146  `managed_session()` launches the Checkpoint and Summary services (threads).
147  If you need more services to run you can simply launch them in the block
148  controlled by `managed_session()`.
149
150  Example: Start a thread to print losses.  We want this thread to run
151  every 60 seconds, so we launch it with `sv.loop()`.
152
153  ```python
154  ...
155  sv = Supervisor(logdir='/tmp/mydir')
156  with sv.managed_session(FLAGS.master) as sess:
157    sv.loop(60, print_loss, (sess, ))
158    while not sv.should_stop():
159      sess.run(my_train_op)
160  ```
161
162  ##### Launching fewer services
163
164  `managed_session()` launches the "summary" and "checkpoint" threads which use
  either the optional `summary_op` and `saver` passed to the constructor, or
166  default ones created automatically by the supervisor.  If you want to run
167  your own summary and checkpointing logic, disable these services by passing
168  `None` to the `summary_op` and `saver` parameters.
169
170  Example: Create summaries manually every 100 steps in the chief.
171
172  ```python
173  # Create a Supervisor with no automatic summaries.
174  sv = Supervisor(logdir='/tmp/mydir', is_chief=is_chief, summary_op=None)
175  # As summary_op was None, managed_session() does not start the
176  # summary thread.
177  with sv.managed_session(FLAGS.master) as sess:
178    for step in range(1000000):
179      if sv.should_stop():
180        break
181      if is_chief and step % 100 == 0:
182        # Create the summary every 100 chief steps.
183        sv.summary_computed(sess, sess.run(my_summary_op))
184      else:
185        # Train normally
186        sess.run(my_train_op)
187  ```
188
189  ##### Custom model initialization
190
191  `managed_session()` only supports initializing the model by running an
192  `init_op` or restoring from the latest checkpoint.  If you have special
193  initialization needs, see how to specify a `local_init_op` when creating the
194  supervisor.  You can also use the `SessionManager` directly to create a
195  session and check if it could be initialized automatically.
196  """
197
198  # Value to pass for the 'ready_op', 'init_op', 'summary_op', 'saver',
199  # and 'global_step' parameters of Supervisor.__init__() to indicate that
200  # the default behavior should be used.
201  USE_DEFAULT = 0
202
  @deprecation.deprecated(None,
                          "Please switch to tf.train.MonitoredTrainingSession")
  def __init__(self,
               graph=None,
               ready_op=USE_DEFAULT,
               ready_for_local_init_op=USE_DEFAULT,
               is_chief=True,
               init_op=USE_DEFAULT,
               init_feed_dict=None,
               local_init_op=USE_DEFAULT,
               logdir=None,
               summary_op=USE_DEFAULT,
               saver=USE_DEFAULT,
               global_step=USE_DEFAULT,
               save_summaries_secs=120,
               save_model_secs=600,
               recovery_wait_secs=30,
               stop_grace_secs=120,
               checkpoint_basename="model.ckpt",
               session_manager=None,
               summary_writer=USE_DEFAULT,
               init_fn=None,
               local_init_run_options=None):
    """Create a `Supervisor`.

    Args:
      graph: A `Graph`.  The graph that the model will use.  Defaults to the
        default `Graph`.  The supervisor may add operations to the graph before
        creating a session, but the graph should not be modified by the caller
        after passing it to the supervisor.
      ready_op: 1-D string `Tensor`.  This tensor is evaluated by supervisors in
        `prepare_or_wait_for_session()` to check if the model is ready to use.
        The model is considered ready if it returns an empty array.  Defaults to
        the tensor returned from `tf.compat.v1.report_uninitialized_variables()`
        If `None`, the model is not checked for readiness.
      ready_for_local_init_op: 1-D string `Tensor`.  This tensor is evaluated by
        supervisors in `prepare_or_wait_for_session()` to check if the model is
        ready to run the local_init_op. The model is considered ready if it
        returns an empty array. Defaults to `None`. If `None`, the model is not
        checked for readiness before running local_init_op.
      is_chief: If True, create a chief supervisor in charge of initializing and
        restoring the model.  If False, create a supervisor that relies on a
        chief supervisor for inits and restore.
      init_op: `Operation`.  Used by chief supervisors to initialize the model
        when it can not be recovered.  Defaults to an `Operation` that
        initializes all global variables.  If `None`, no initialization is done
        automatically unless you pass a value for `init_fn`, see below.
      init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
        This feed dictionary will be used when `init_op` is evaluated.
      local_init_op: `Operation`. Used by all supervisors to run initializations
        that should run for every new supervisor instance. By default these are
        table initializers and initializers for local variables. If `None`, no
        further per supervisor-instance initialization is done automatically.
      logdir: A string.  Optional path to a directory where to checkpoint the
        model and log events for the visualizer.  Used by chief supervisors. The
        directory will be created if it does not exist.
      summary_op: An `Operation` that returns a Summary for the event logs. Used
        by chief supervisors if a `logdir` was specified.  Defaults to the
        operation returned from summary.merge_all().  If `None`, summaries are
        not computed automatically.
      saver: A Saver object.  Used by chief supervisors if a `logdir` was
        specified.  Defaults to the saver returned by Saver(). If `None`, the
        model is not saved automatically.
      global_step: An integer Tensor of size 1 that counts steps.  The value
        from 'global_step' is used in summaries and checkpoint filenames.
        Defaults to the op named 'global_step' in the graph if it exists, is of
        rank 1, size 1, and of type tf.int32 or tf.int64.  If `None` the global
        step is not recorded in summaries and checkpoint files.  Used by chief
        supervisors if a `logdir` was specified.
      save_summaries_secs: Number of seconds between the computation of
        summaries for the event log.  Defaults to 120 seconds.  Pass 0 to
        disable summaries.
      save_model_secs: Number of seconds between the creation of model
        checkpoints.  Defaults to 600 seconds.  Pass 0 to disable checkpoints.
      recovery_wait_secs: Number of seconds between checks that the model is
        ready.  Used by supervisors when waiting for a chief supervisor to
        initialize or restore the model.  Defaults to 30 seconds.
      stop_grace_secs: Grace period, in seconds, given to running threads to
        stop when `stop()` is called.  Defaults to 120 seconds.
      checkpoint_basename: The basename for checkpoint saving.
      session_manager: `SessionManager`, which manages Session creation and
        recovery. If it is `None`, a default `SessionManager` will be created
        with the set of arguments passed in for backwards compatibility.
      summary_writer: `SummaryWriter` to use or `USE_DEFAULT`.  Can be `None` to
        indicate that no summaries should be written.
      init_fn: Optional callable used to initialize the model. Called after the
        optional `init_op` is called.  The callable must accept one argument,
        the session being initialized.
      local_init_run_options: RunOptions to be passed as the SessionManager
        local_init_run_options parameter.

    Returns:
      A `Supervisor`.

    Raises:
      RuntimeError: If called with eager execution enabled.

    @compatibility(eager)
    `Supervisor`s are not supported when eager execution is enabled.
    @end_compatibility
    """
    if context.executing_eagerly():
      raise RuntimeError("Supervisors are incompatible with eager execution.")
    # Set default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    # All default ops (ready/init/local_init/saver/summary/global_step) must
    # be created inside the managed graph, before it is finalized below.
    with graph.as_default():
      self._init_ready_op(
          ready_op=ready_op, ready_for_local_init_op=ready_for_local_init_op)
      self._init_init_op(init_op=init_op, init_feed_dict=init_feed_dict)
      self._init_local_init_op(local_init_op=local_init_op)
      self._init_saver(saver=saver)
      self._init_summary_op(summary_op=summary_op)
      self._init_global_step(global_step=global_step)
    self._graph = graph
    self._meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph.as_graph_def(add_shapes=True),
        saver_def=self._saver.saver_def if self._saver else None)
    self._is_chief = is_chief
    self._coord = coordinator.Coordinator()
    self._recovery_wait_secs = recovery_wait_secs
    self._stop_grace_secs = stop_grace_secs
    self._init_fn = init_fn
    self._local_init_run_options = local_init_run_options

    # Set all attributes related to checkpointing and writing events to None.
    # Afterwards, set them appropriately for chief supervisors, as these are
    # the only supervisors that can write checkpoints and events.
    self._logdir = None
    self._save_summaries_secs = None
    self._save_model_secs = None
    self._save_path = None
    self._summary_writer = None

    if self._is_chief:
      self._logdir = logdir
      self._save_summaries_secs = save_summaries_secs
      self._save_model_secs = save_model_secs
      if self._logdir:
        self._save_path = os.path.join(self._logdir, checkpoint_basename)
      if summary_writer is Supervisor.USE_DEFAULT:
        # A default FileWriter is only created when there is a logdir to
        # write to; otherwise no summaries are written.
        if self._logdir:
          self._summary_writer = _summary.FileWriter(self._logdir)
      else:
        self._summary_writer = summary_writer
      self._graph_added_to_summary = False

    self._init_session_manager(session_manager=session_manager)
    self._verify_setup()
    # The graph is not allowed to change anymore.
    graph.finalize()
354
355  def _init_session_manager(self, session_manager=None):
356    if session_manager is None:
357      self._session_manager = session_manager_mod.SessionManager(
358          local_init_op=self._local_init_op,
359          ready_op=self._ready_op,
360          ready_for_local_init_op=self._ready_for_local_init_op,
361          graph=self._graph,
362          recovery_wait_secs=self._recovery_wait_secs,
363          local_init_run_options=self._local_init_run_options)
364    else:
365      self._session_manager = session_manager
366
367  def _get_first_op_from_collection(self, key):
368    """Returns the first `Operation` from a collection.
369
370    Args:
371      key: A string collection key.
372
373    Returns:
374      The first Op found in a collection, or `None` if the collection is empty.
375    """
376    try:
377      op_list = ops.get_collection(key)
378      if len(op_list) > 1:
379        logging.info("Found %d %s operations. Returning the first one.",
380                     len(op_list), key)
381      if op_list:
382        return op_list[0]
383    except LookupError:
384      pass
385
386    return None
387
388  def _init_ready_op(self,
389                     ready_op=USE_DEFAULT,
390                     ready_for_local_init_op=USE_DEFAULT):
391    """Initializes ready_op.
392
393    Args:
394      ready_op: `Tensor` to check if the model is initialized. If it's set to
395        USE_DEFAULT, creates an op that checks all the variables are
396        initialized.
397      ready_for_local_init_op: `Tensor` to check if the model is ready to run
398        local_init_op. If it's set to USE_DEFAULT, creates an op that checks all
399        the global variables are initialized.
400    """
401    if ready_op is Supervisor.USE_DEFAULT:
402      ready_op = self._get_first_op_from_collection(ops.GraphKeys.READY_OP)
403      if ready_op is None:
404        ready_op = variables.report_uninitialized_variables()
405        ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op)
406    self._ready_op = ready_op
407
408    # ready_for_local_init_op defaults to None for backward compatibility
409    if ready_for_local_init_op is Supervisor.USE_DEFAULT:
410      ready_for_local_init_op = self._get_first_op_from_collection(
411          ops.GraphKeys.READY_FOR_LOCAL_INIT_OP)
412    self._ready_for_local_init_op = ready_for_local_init_op
413
414  def _init_init_op(self, init_op=USE_DEFAULT, init_feed_dict=None):
415    """Initializes init_op.
416
417    Args:
418      init_op: `Operation` to initialize the variables. If set to USE_DEFAULT,
419        create an op that initializes all variables and tables.
420      init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
421        This feed dictionary will be used when `init_op` is evaluated.
422    """
423    if init_op is Supervisor.USE_DEFAULT:
424      init_op = self._get_first_op_from_collection(ops.GraphKeys.INIT_OP)
425      if init_op is None:
426        init_op = variables.global_variables_initializer()
427        ops.add_to_collection(ops.GraphKeys.INIT_OP, init_op)
428    self._init_op = init_op
429    self._init_feed_dict = init_feed_dict
430
431  def _init_local_init_op(self, local_init_op=USE_DEFAULT):
432    """Initializes local_init_op.
433
434    Args:
435      local_init_op: `Operation` run for every new supervisor instance. If set
436        to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP
437        collection. If the collection is empty, create an op that initializes
438        all local variables and all tables.
439    """
440    if local_init_op is Supervisor.USE_DEFAULT:
441      local_init_op = self._get_first_op_from_collection(
442          ops.GraphKeys.LOCAL_INIT_OP)
443      if local_init_op is None:
444        op_list = [
445            variables.local_variables_initializer(),
446            lookup_ops.tables_initializer()
447        ]
448        if op_list:
449          local_init_op = control_flow_ops.group(*op_list)
450          ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
451    self._local_init_op = local_init_op
452
453  def _init_saver(self, saver=USE_DEFAULT):
454    """Initializes saver.
455
456    Args:
457      saver: A `Saver` object. If set to USE_DEFAULT, create one that saves all
458        the variables.
459    """
460    if saver is Supervisor.USE_DEFAULT:
461      saver = self._get_first_op_from_collection(ops.GraphKeys.SAVERS)
462      if saver is None and variables.global_variables():
463        saver = saver_mod.Saver()
464        ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
465    self._saver = saver
466
467  def _init_summary_op(self, summary_op=USE_DEFAULT):
468    """Initializes summary_op.
469
470    Args:
471      summary_op: An Operation that returns a Summary for the event logs. If set
472        to USE_DEFAULT, create an op that merges all the summaries.
473    """
474    if summary_op is Supervisor.USE_DEFAULT:
475      summary_op = self._get_first_op_from_collection(ops.GraphKeys.SUMMARY_OP)
476      if summary_op is None:
477        summary_op = _summary.merge_all()
478        if summary_op is not None:
479          ops.add_to_collection(ops.GraphKeys.SUMMARY_OP, summary_op)
480    self._summary_op = summary_op
481
482  def _init_global_step(self, global_step=USE_DEFAULT):
483    """Initializes global_step.
484
485    Args:
486      global_step: An integer Tensor of size 1 that counts steps. If set to
487        USE_DEFAULT, creates global_step tensor.
488    """
489    if global_step is Supervisor.USE_DEFAULT:
490      global_step = self._get_first_op_from_collection(
491          ops.GraphKeys.GLOBAL_STEP)
492      if global_step is None:
493        global_step = self._default_global_step_tensor()
494        if global_step is not None:
495          ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, global_step)
496    self._global_step = global_step
497
498  @property
499  def is_chief(self):
500    """Return True if this is a chief supervisor.
501
502    Returns:
503      A bool.
504    """
505    return self._is_chief
506
507  @property
508  def session_manager(self):
509    """Return the SessionManager used by the Supervisor.
510
511    Returns:
512      A SessionManager object.
513    """
514    return self._session_manager
515
516  @property
517  def coord(self):
518    """Return the Coordinator used by the Supervisor.
519
520    The Coordinator can be useful if you want to run multiple threads
521    during your training.
522
523    Returns:
524      A Coordinator object.
525    """
526    return self._coord
527
528  @property
529  def init_op(self):
530    """Return the Init Op used by the supervisor.
531
532    Returns:
533      An Op or `None`.
534    """
535    return self._init_op
536
537  @property
538  def init_feed_dict(self):
539    """Return the feed dictionary used when evaluating the `init_op`.
540
541    Returns:
542      A feed dictionary or `None`.
543    """
544    return self._init_feed_dict
545
546  @property
547  def ready_op(self):
548    """Return the Ready Op used by the supervisor.
549
550    Returns:
551      An Op or `None`.
552    """
553    return self._ready_op
554
555  @property
556  def ready_for_local_init_op(self):
557    return self._ready_for_local_init_op
558
559  @property
560  def summary_writer(self):
561    """Return the SummaryWriter used by the chief supervisor.
562
563    Returns:
564      A SummaryWriter.
565    """
566    return self._summary_writer
567
568  @property
569  def summary_op(self):
570    """Return the Summary Tensor used by the chief supervisor.
571
572    Returns:
573      A string Tensor for the summary or `None`.
574    """
575    return self._summary_op
576
577  @property
578  def save_summaries_secs(self):
579    """Return the delay between summary computations.
580
581    Returns:
582      A timestamp.
583    """
584    return self._save_summaries_secs
585
586  @property
587  def global_step(self):
588    """Return the global_step Tensor used by the supervisor.
589
590    Returns:
591      An integer Tensor for the global_step.
592    """
593    return self._global_step
594
595  @property
596  def saver(self):
597    """Return the Saver used by the supervisor.
598
599    Returns:
600      A Saver object.
601    """
602    return self._saver
603
604  @property
605  def save_model_secs(self):
606    """Return the delay between checkpoints.
607
608    Returns:
609      A timestamp.
610    """
611    return self._save_model_secs
612
613  @property
614  def save_path(self):
615    """Return the save path used by the supervisor.
616
617    Returns:
618      A string.
619    """
620    return self._save_path
621
622  def _write_graph(self):
623    """Writes graph_def to `logdir` and adds it to summary if applicable."""
624    assert self._is_chief
625    if self._logdir:
626      training_util.write_graph(
627          self._graph.as_graph_def(add_shapes=True), self._logdir,
628          "graph.pbtxt")
629    if self._summary_writer and not self._graph_added_to_summary:
630      self._summary_writer.add_graph(self._graph)
631      self._summary_writer.add_meta_graph(self._meta_graph_def)
632      self._graph_added_to_summary = True
633
634  def start_standard_services(self, sess):
635    """Start the standard services for 'sess'.
636
637    This starts services in the background.  The services started depend
638    on the parameters to the constructor and may include:
639
640      - A Summary thread computing summaries every save_summaries_secs.
641      - A Checkpoint thread saving the model every save_model_secs.
642      - A StepCounter thread measure step time.
643
644    Args:
645      sess: A Session.
646
647    Returns:
648      A list of threads that are running the standard services.  You can use
649      the Supervisor's Coordinator to join these threads with:
650        sv.coord.Join(<list of threads>)
651
652    Raises:
653      RuntimeError: If called with a non-chief Supervisor.
654      ValueError: If not `logdir` was passed to the constructor as the
655        services need a log directory.
656    """
657    if not self._is_chief:
658      raise RuntimeError("Only chief supervisor can start standard services. "
659                         "Because only chief supervisors can write events.")
660
661    if not self._logdir:
662      logging.warning("Standard services need a 'logdir' "
663                      "passed to the SessionManager")
664      return
665
666    if self._global_step is not None and self._summary_writer:
667      # Only add the session log if we keep track of global step.
668      # TensorBoard cannot use START message for purging expired events
669      # if there is no step value.
670      current_step = training_util.global_step(sess, self._global_step)
671      self._summary_writer.add_session_log(
672          SessionLog(status=SessionLog.START), current_step)
673
674    threads = []
675    if self._save_summaries_secs and self._summary_writer:
676      if self._summary_op is not None:
677        threads.append(SVSummaryThread(self, sess))
678      if self._global_step is not None:
679        threads.append(SVStepCounterThread(self, sess))
680    if self.saver and self._save_model_secs:
681      threads.append(SVTimerCheckpointThread(self, sess))
682    for t in threads:
683      t.start()
684    return threads
685
686  def prepare_or_wait_for_session(self,
687                                  master="",
688                                  config=None,
689                                  wait_for_checkpoint=False,
690                                  max_wait_secs=7200,
691                                  start_standard_services=True):
692    """Make sure the model is ready to be used.
693
694    Create a session on 'master', recovering or initializing the model as
695    needed, or wait for a session to be ready.  If running as the chief
696    and `start_standard_service` is set to True, also call the session
697    manager to start the standard services.
698
699    Args:
700      master: name of the TensorFlow master to use.  See the
701        `tf.compat.v1.Session` constructor for how this is interpreted.
702      config: Optional ConfigProto proto used to configure the session, which is
703        passed as-is to create the session.
704      wait_for_checkpoint: Whether we should wait for the availability of a
705        checkpoint before creating Session. Defaults to False.
706      max_wait_secs: Maximum time to wait for the session to become available.
707      start_standard_services: Whether to start the standard services and the
708        queue runners.
709
710    Returns:
711      A Session object that can be used to drive the model.
712    """
713    # For users who recreate the session with prepare_or_wait_for_session(), we
714    # need to clear the coordinator's stop_event so that threads managed by the
715    # coordinator can run.
716    self._coord.clear_stop()
717    if self._summary_writer:
718      self._summary_writer.reopen()
719
720    if self._is_chief:
721      sess = self._session_manager.prepare_session(
722          master,
723          init_op=self.init_op,
724          saver=self.saver,
725          checkpoint_dir=self._logdir,
726          wait_for_checkpoint=wait_for_checkpoint,
727          max_wait_secs=max_wait_secs,
728          config=config,
729          init_feed_dict=self._init_feed_dict,
730          init_fn=self._init_fn)
731      self._write_graph()
732      if start_standard_services:
733        logging.info("Starting standard services.")
734        self.start_standard_services(sess)
735    else:
736      sess = self._session_manager.wait_for_session(
737          master, config=config, max_wait_secs=max_wait_secs)
738    if start_standard_services:
739      logging.info("Starting queue runners.")
740      self.start_queue_runners(sess)
741    return sess
742
743  def start_queue_runners(self, sess, queue_runners=None):
744    """Start threads for `QueueRunners`.
745
746    Note that the queue runners collected in the graph key `QUEUE_RUNNERS`
747    are already started automatically when you create a session with the
748    supervisor, so unless you have non-collected queue runners to start
749    you do not need to call this explicitly.
750
751    Args:
752      sess: A `Session`.
753      queue_runners: A list of `QueueRunners`. If not specified, we'll use the
754        list of queue runners gathered in the graph under the key
755        `GraphKeys.QUEUE_RUNNERS`.
756
757    Returns:
758      The list of threads started for the `QueueRunners`.
759
760    Raises:
761      RuntimeError: If called with eager execution enabled.
762
763    @compatibility(eager)
764    Queues are not compatible with eager execution. To ingest data when eager
765    execution is enabled, use the `tf.data` API.
766    @end_compatibility
767    """
768    if context.executing_eagerly():
769      raise RuntimeError("Queues are not compatible with eager execution.")
770    if queue_runners is None:
771      queue_runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
772    threads = []
773    for qr in queue_runners:
774      threads.extend(
775          qr.create_threads(sess, coord=self._coord, daemon=True, start=True))
776    return threads
777
778  def loop(self, timer_interval_secs, target, args=None, kwargs=None):
779    """Start a LooperThread that calls a function periodically.
780
781    If `timer_interval_secs` is None the thread calls `target(*args, **kwargs)`
782    repeatedly.  Otherwise it calls it every `timer_interval_secs`
783    seconds.  The thread terminates when a stop is requested.
784
785    The started thread is added to the list of threads managed by the supervisor
786    so it does not need to be passed to the `stop()` method.
787
788    Args:
789      timer_interval_secs: Number. Time boundaries at which to call `target`.
790      target: A callable object.
791      args: Optional arguments to pass to `target` when calling it.
792      kwargs: Optional keyword arguments to pass to `target` when calling it.
793
794    Returns:
795      The started thread.
796    """
797    looper = coordinator.LooperThread(
798        self._coord,
799        timer_interval_secs,
800        target=target,
801        args=args,
802        kwargs=kwargs)
803    looper.start()
804    return looper
805
806  def stop(self,
807           threads=None,
808           close_summary_writer=True,
809           ignore_live_threads=False):
810    """Stop the services and the coordinator.
811
812    This does not close the session.
813
814    Args:
815      threads: Optional list of threads to join with the coordinator.  If
816        `None`, defaults to the threads running the standard services, the
817        threads started for `QueueRunners`, and the threads started by the
818        `loop()` method.  To wait on additional threads, pass the list in this
819        parameter.
820      close_summary_writer: Whether to close the `summary_writer`.  Defaults to
821        `True` if the summary writer was created by the supervisor, `False`
822        otherwise.
823      ignore_live_threads: If `True` ignores threads that remain running after a
824        grace period when joining threads via the coordinator, instead of
825        raising a RuntimeError.
826    """
827    self._coord.request_stop()
828    try:
829      # coord.join() re-raises the first reported exception; the "finally"
830      # block ensures that we clean up whether or not an exception was
831      # reported.
832      self._coord.join(
833          threads,
834          stop_grace_period_secs=self._stop_grace_secs,
835          ignore_live_threads=ignore_live_threads)
836    finally:
837      # Close the writer last, in case one of the running threads was using it.
838      if close_summary_writer and self._summary_writer:
839        # Stop messages are not logged with event.step,
840        # since the session may have already terminated.
841        self._summary_writer.add_session_log(SessionLog(status=SessionLog.STOP))
842        self._summary_writer.close()
843        self._graph_added_to_summary = False
844
845  def request_stop(self, ex=None):
846    """Request that the coordinator stop the threads.
847
848    See `Coordinator.request_stop()`.
849
850    Args:
851      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
852        `sys.exc_info()`.  If this is the first call to `request_stop()` the
853        corresponding exception is recorded and re-raised from `join()`.
854    """
855    self._coord.request_stop(ex=ex)
856
857  def should_stop(self):
858    """Check if the coordinator was told to stop.
859
860    See `Coordinator.should_stop()`.
861
862    Returns:
863      True if the coordinator was told to stop, False otherwise.
864    """
865    return self._coord.should_stop()
866
867  def stop_on_exception(self):
868    """Context handler to stop the supervisor when an exception is raised.
869
870    See `Coordinator.stop_on_exception()`.
871
872    Returns:
873      A context handler.
874    """
875    return self._coord.stop_on_exception()
876
877  def wait_for_stop(self):
878    """Block waiting for the coordinator to stop."""
879    self._coord.wait_for_stop()
880
881  def summary_computed(self, sess, summary, global_step=None):
882    """Indicate that a summary was computed.
883
884    Args:
885      sess: A `Session` object.
886      summary: A Summary proto, or a string holding a serialized summary proto.
887      global_step: Int. global step this summary is associated with. If `None`,
888        it will try to fetch the current step.
889
890    Raises:
891      TypeError: if 'summary' is not a Summary proto or a string.
892      RuntimeError: if the Supervisor was created without a `logdir`.
893    """
894    if not self._summary_writer:
895      raise RuntimeError("Writing a summary requires a summary writer.")
896    if global_step is None and self.global_step is not None:
897      global_step = training_util.global_step(sess, self.global_step)
898    self._summary_writer.add_summary(summary, global_step)
899
900  def _default_global_step_tensor(self):
901    """Returns the global_step from the default graph.
902
903    Returns:
904      The global step `Tensor` or `None`.
905    """
906    try:
907      gs = ops.get_default_graph().get_tensor_by_name("global_step:0")
908      if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]:
909        return gs
910      else:
911        logging.warning("Found 'global_step' is not an int type: %s", gs.dtype)
912        return None
913    except KeyError:
914      return None
915
916  def _verify_setup(self):
917    """Check that all is good.
918
919    Raises:
920      ValueError: If something is not good.
921    """
922    # Not running as chief means that replicas are used.
923    # In that case all Variables must have their device set.
924    if not self._is_chief:
925      for op in self._graph.get_operations():
926        if op.type in ["Variable", "VariableV2"] and not op.device:
927          raise ValueError("When using replicas, all Variables must have "
928                           "their device set: %s" % op)
929
930  # pylint: disable=g-doc-return-or-yield,broad-except
931  @contextlib.contextmanager
932  def managed_session(self,
933                      master="",
934                      config=None,
935                      start_standard_services=True,
936                      close_summary_writer=True):
937    """Returns a context manager for a managed session.
938
939    This context manager creates and automatically recovers a session.  It
940    optionally starts the standard services that handle checkpoints and
941    summaries.  It monitors exceptions raised from the `with` block or from the
942    services and stops the supervisor as needed.
943
944    The context manager is typically used as follows:
945
946    ```python
947    def train():
948      sv = tf.compat.v1.train.Supervisor(...)
949      with sv.managed_session(<master>) as sess:
950        for step in range(..):
951          if sv.should_stop():
952            break
953          sess.run(<my training op>)
954          ...do other things needed at each training step...
955    ```
956
957    An exception raised from the `with` block or one of the service threads is
958    raised again when the block exits.  This is done after stopping all threads
959    and closing the session.  For example, an `AbortedError` exception, raised
960    in case of preemption of one of the workers in a distributed model, is
961    raised again when the block exits.
962
963    If you want to retry the training loop in case of preemption you can do it
964    as follows:
965
966    ```python
967    def main(...):
968      while True
969        try:
970          train()
971        except tf.errors.Aborted:
972          pass
973    ```
974
975    As a special case, exceptions used for control flow, such as
976    `OutOfRangeError` which reports that input queues are exhausted, are not
977    raised again from the `with` block: they indicate a clean termination of
978    the training loop and are considered normal termination.
979
980    Args:
981      master: name of the TensorFlow master to use.  See the
982        `tf.compat.v1.Session` constructor for how this is interpreted.
983      config: Optional `ConfigProto` proto used to configure the session. Passed
984        as-is to create the session.
985      start_standard_services: Whether to start the standard services, such as
986        checkpoint, summary and step counter.
987      close_summary_writer: Whether to close the summary writer when closing the
988        session.  Defaults to True.
989
990    Returns:
991      A context manager that yields a `Session` restored from the latest
992      checkpoint or initialized from scratch if not checkpoint exists.  The
993      session is closed when the `with` block exits.
994    """
995    try:
996      sess = self.prepare_or_wait_for_session(
997          master=master,
998          config=config,
999          start_standard_services=start_standard_services)
1000      yield sess
1001    except Exception as e:
1002      self.request_stop(e)
1003    finally:
1004      try:
1005        # Request all the threads to stop and wait for them to do so.  Any
1006        # exception raised by the threads is raised again from stop().
1007        # Passing stop_grace_period_secs is for blocked enqueue/dequeue
1008        # threads which are not checking for `should_stop()`.  They
1009        # will be stopped when we close the session further down.
1010        self.stop(close_summary_writer=close_summary_writer)
1011      finally:
1012        # Close the session to finish up all pending calls.  We do not care
1013        # about exceptions raised when closing.  This takes care of
1014        # blocked enqueue/dequeue calls.
1015        try:
1016          sess.close()
1017        except Exception:
1018          # Silently ignore exceptions raised by close().
1019          pass
1020
1021  # pylint: enable=g-doc-return-or-yield,broad-except
1022
1023
class SVSummaryThread(coordinator.LooperThread):
  """A thread to save summaries on a timer."""

  def __init__(self, sv, sess):
    """Create a SVSummaryThread.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
    """
    super().__init__(sv.coord, sv.save_summaries_secs)
    self._sv = sv
    self._sess = sess

  def run_loop(self):
    # Fetch the summaries, together with the global step when one exists so
    # the event can be tagged with it.
    sv = self._sv
    if sv.global_step is None:
      summary_strs = self._sess.run(sv.summary_op)
      global_step = None
    else:
      summary_strs, global_step = self._sess.run(
          [sv.summary_op, sv.global_step])
    if sv.summary_writer:
      logging.info("Recording summary at step %s.", global_step)
      sv.summary_writer.add_summary(summary_strs, global_step)
1048
1049
class SVStepCounterThread(coordinator.LooperThread):
  """Threads to count steps and measure their duration."""

  def __init__(self, sv, sess, step_counter=None):
    """Create a `SVStepCounterThread`.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
      step_counter: A `Tensor` holding the step counter. By defaults, it uses
        sv.global_step.
    """
    super().__init__(sv.coord, sv.save_summaries_secs)
    self._sv = sv
    self._sess = sess
    self._last_time = 0.0
    self._last_step = 0
    if step_counter is None:
      step_counter = sv.global_step
    self._step_counter = step_counter
    self._summary_tag = "%s/sec" % self._step_counter.op.name

  def start_loop(self):
    # Record the baseline so the first run_loop() measures a real interval.
    self._last_time = time.time()
    self._last_step = training_util.global_step(self._sess, self._step_counter)

  def run_loop(self):
    # Number of steps completed since the previous iteration.
    current_step = training_util.global_step(self._sess, self._step_counter)
    added_steps = current_step - self._last_step
    self._last_step = current_step
    # Elapsed wall time since the previous iteration.
    now = time.time()
    elapsed_time = now - self._last_time
    self._last_time = now
    # Report the rate of steps per second.
    if elapsed_time > 0.:
      steps_per_sec = added_steps / elapsed_time
    else:
      steps_per_sec = float("inf")
    rate_summary = Summary(value=[
        Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
    ])
    if self._sv.summary_writer:
      self._sv.summary_writer.add_summary(rate_summary, current_step)
    logging.log_first_n(logging.INFO, "%s: %g", 10, self._summary_tag,
                        steps_per_sec)
1096
1097
class SVTimerCheckpointThread(coordinator.LooperThread):
  """A thread to checkpoint on a timer."""

  def __init__(self, sv, sess):
    """Create a `SVTimerCheckpointThread`.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
    """
    super().__init__(sv.coord, sv.save_model_secs)
    self._sv = sv
    self._sess = sess

  def run_loop(self):
    sv = self._sv
    logging.info("Saving checkpoint to path %s", sv.save_path)
    sv.saver.save(self._sess, sv.save_path, global_step=sv.global_step)
    # Record the checkpoint event in the summary log when possible.
    if sv.summary_writer and sv.global_step is not None:
      current_step = training_util.global_step(self._sess, sv.global_step)
      sv.summary_writer.add_session_log(
          SessionLog(
              status=SessionLog.CHECKPOINT, checkpoint_path=sv.save_path),
          current_step)
1122
1123
# TODO(sherrym): All non-PEP8 compliant names will be deprecated shortly.
# Legacy CamelCase aliases kept for backward compatibility.
_CAMEL_CASE_ALIASES = (
    ("PrepareSession", Supervisor.prepare_or_wait_for_session),
    ("StartQueueRunners", Supervisor.start_queue_runners),
    ("StartStandardServices", Supervisor.start_standard_services),
    ("Stop", Supervisor.stop),
    ("RequestStop", Supervisor.request_stop),
    ("Loop", Supervisor.loop),
    ("ShouldStop", Supervisor.should_stop),
    ("StopOnException", Supervisor.stop_on_exception),
    ("WaitForStop", Supervisor.wait_for_stop),
    ("SummaryComputed", Supervisor.summary_computed),
)
for _alias, _method in _CAMEL_CASE_ALIASES:
  setattr(Supervisor, _alias, _method)
del _CAMEL_CASE_ALIASES, _alias, _method
1135