# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training helper that checkpoints models and computes summaries."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import os
import time

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as _summary
from tensorflow.python.training import coordinator
from tensorflow.python.training import saver as saver_mod
from tensorflow.python.training import session_manager as session_manager_mod
from tensorflow.python.training import training_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["train.Supervisor"])
class Supervisor(object):
  """A training helper that checkpoints models and computes summaries.

  This class is deprecated. Please use
  `tf.train.MonitoredTrainingSession` instead.

  The Supervisor is a small wrapper around a `Coordinator`, a `Saver`,
  and a `SessionManager` that takes care of common needs of TensorFlow
  training programs.

  #### Use for a single program

  ```python
  with tf.Graph().as_default():
    ...add operations to the graph...
    # Create a Supervisor that will checkpoint the model in '/tmp/mydir'.
    sv = Supervisor(logdir='/tmp/mydir')
    # Get a TensorFlow session managed by the supervisor.
    with sv.managed_session(FLAGS.master) as sess:
      # Use the session to train the graph.
      while not sv.should_stop():
        sess.run(<my_train_op>)
  ```

  Within the `with sv.managed_session()` block all variables in the graph have
  been initialized.  In addition, a few services have been started to
  checkpoint the model and add summaries to the event log.

  If the program crashes and is restarted, the managed session automatically
  reinitializes variables from the most recent checkpoint.

  The supervisor is notified of any exception raised by one of the services.
  After an exception is raised, `should_stop()` returns `True`.  In that case
  the training loop should also stop.  This is why the training loop has to
  check for `sv.should_stop()`.

  Exceptions that indicate that the training inputs have been exhausted,
  `tf.errors.OutOfRangeError`, also cause `sv.should_stop()` to return `True`
  but are not re-raised from the `with` block: they indicate a normal
  termination.

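  For example, a minimal sketch of a loop that ends cleanly when a bounded
  input pipeline is exhausted (`my_train_op` is a placeholder):

  ```python
  sv = Supervisor(logdir='/tmp/mydir')
  with sv.managed_session(FLAGS.master) as sess:
    while not sv.should_stop():
      # Raises tf.errors.OutOfRangeError once the inputs are exhausted; the
      # supervisor treats it as a clean shutdown and the `with` block exits
      # normally instead of re-raising.
      sess.run(my_train_op)
  ```
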
  #### Use for multiple replicas

  To train with replicas, you deploy the same program in a `Cluster`.
  One of the tasks must be identified as the *chief*: the task that handles
  initialization, checkpoints, summaries, and recovery.  The other tasks
  depend on the *chief* for these services.

93  if the program is running as the *chief*.
94
  ```python
  # Choose a task as the chief. This could be based on server_def.task_index,
  # or job_def.name, or job_def.tasks. It's entirely up to the end user.
  # But there can be only one *chief*.
  is_chief = (server_def.task_index == 0)
  server = tf.train.Server(server_def)

  with tf.Graph().as_default():
    ...add operations to the graph...
    # Create a Supervisor that uses log directory on a shared file system.
    # Indicate if you are the 'chief'
    sv = Supervisor(logdir='/shared_directory/...', is_chief=is_chief)
    # Get a Session in a TensorFlow server on the cluster.
    with sv.managed_session(server.target) as sess:
      # Use the session to train the graph.
      while not sv.should_stop():
        sess.run(<my_train_op>)
  ```

  In the *chief* task, the `Supervisor` works exactly as in the first example
  above.  In the other tasks, `sv.managed_session()` waits for the model to
  have been initialized before returning a session to the training code.  The
  non-chief tasks depend on the chief task for initializing the model.

  If one of the tasks crashes and restarts, `managed_session()`
  checks if the model is initialized.  If yes, it just creates a session and
  returns it to the training code, which proceeds normally.  If the model needs
  to be initialized, the chief task takes care of reinitializing it; the other
  tasks just wait for the model to have been initialized.

  NOTE: This modified program still works fine as a single program.
  The single program marks itself as the chief.

  #### What `master` string to use

  Whether you are running on your own machine or in a cluster, you can use the
  following values for the --master flag:

  * Specifying `''` requests an in-process session that does not use RPC.

  * Specifying `'local'` requests a session that uses the RPC-based
    "Master interface" to run TensorFlow programs. See
    `tf.train.Server.create_local_server` for
    details.

  * Specifying `'grpc://hostname:port'` requests a session that uses
    the RPC interface to a specific host, and also allows the in-process
    master to access remote TensorFlow workers. Often, it is
    appropriate to pass `server.target` (for some `tf.train.Server`
    named `server`).

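  For example, a sketch of the two common cases (the server construction is an
  assumption; use whatever your deployment provides):

  ```python
  # In-process master, no RPC:
  with sv.managed_session('') as sess:
    ...
  # RPC-based master, typically obtained from a tf.train.Server:
  server = tf.train.Server.create_local_server()
  with sv.managed_session(server.target) as sess:
    ...
  ```
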
  #### Advanced use

  ##### Launching additional services

  `managed_session()` launches the Checkpoint and Summary services (threads).
  If you need more services to run, you can simply launch them in the block
  controlled by `managed_session()`.

  Example: Start a thread to print losses.  We want this thread to run
  every 60 seconds, so we launch it with `sv.loop()`.

  ```python
  ...
  sv = Supervisor(logdir='/tmp/mydir')
  with sv.managed_session(FLAGS.master) as sess:
    sv.loop(60, print_loss, (sess,))
    while not sv.should_stop():
      sess.run(my_train_op)
  ```

  ##### Launching fewer services

  `managed_session()` launches the "summary" and "checkpoint" threads, which
  use either the optional `summary_op` and `saver` passed to the constructor,
  or default ones created automatically by the supervisor.  If you want to run
  your own summary and checkpointing logic, disable these services by passing
  `None` to the `summary_op` and `saver` parameters.

  Example: Create summaries manually every 100 steps in the chief.

  ```python
  # Create a Supervisor with no automatic summaries.
  sv = Supervisor(logdir='/tmp/mydir', is_chief=is_chief, summary_op=None)
  # As summary_op was None, managed_session() does not start the
  # summary thread.
  with sv.managed_session(FLAGS.master) as sess:
    for step in range(1000000):
      if sv.should_stop():
        break
      if is_chief and step % 100 == 0:
        # Create the summary every 100 chief steps.
        sv.summary_computed(sess, sess.run(my_summary_op))
      else:
        # Train normally
        sess.run(my_train_op)
  ```

  ##### Custom model initialization

  `managed_session()` only supports initializing the model by running an
  `init_op` or restoring from the latest checkpoint.  If you have special
  initialization needs, see how to specify a `local_init_op` when creating the
  supervisor.  You can also use the `SessionManager` directly to create a
  session and check if it could be initialized automatically.
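
  For example, a minimal sketch of driving `SessionManager` yourself (the init
  op, saver, and checkpoint directory are placeholders):

  ```python
  sm = tf.train.SessionManager()
  sess = sm.prepare_session(
      '', init_op=tf.global_variables_initializer(),
      saver=tf.train.Saver(), checkpoint_dir='/tmp/mydir')
  ```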
  """

  # Value to pass for the 'ready_op', 'init_op', 'summary_op', 'saver',
  # and 'global_step' parameters of Supervisor.__init__() to indicate that
  # the default behavior should be used.
  USE_DEFAULT = 0

  @deprecation.deprecated(None,
                          "Please switch to tf.train.MonitoredTrainingSession")
  def __init__(self,
               graph=None,
               ready_op=USE_DEFAULT,
               ready_for_local_init_op=USE_DEFAULT,
               is_chief=True,
               init_op=USE_DEFAULT,
               init_feed_dict=None,
               local_init_op=USE_DEFAULT,
               logdir=None,
               summary_op=USE_DEFAULT,
               saver=USE_DEFAULT,
               global_step=USE_DEFAULT,
               save_summaries_secs=120,
               save_model_secs=600,
               recovery_wait_secs=30,
               stop_grace_secs=120,
               checkpoint_basename="model.ckpt",
               session_manager=None,
               summary_writer=USE_DEFAULT,
               init_fn=None,
               local_init_run_options=None):
    """Create a `Supervisor`.

    Args:
      graph: A `Graph`.  The graph that the model will use.  Defaults to the
        default `Graph`.  The supervisor may add operations to the graph before
        creating a session, but the graph should not be modified by the caller
        after passing it to the supervisor.
      ready_op: 1-D string `Tensor`.  This tensor is evaluated by supervisors in
        `prepare_or_wait_for_session()` to check if the model is ready to use.
        The model is considered ready if it returns an empty array.  Defaults to
        the tensor returned from `tf.report_uninitialized_variables()`.  If
        `None`, the model is not checked for readiness.
      ready_for_local_init_op: 1-D string `Tensor`.  This tensor is evaluated by
        supervisors in `prepare_or_wait_for_session()` to check if the model is
        ready to run the local_init_op.
        The model is considered ready if it returns an empty array. Defaults to
        `None`. If `None`, the model is not checked for readiness before running
        local_init_op.
      is_chief: If True, create a chief supervisor in charge of initializing
        and restoring the model.  If False, create a supervisor that relies
        on a chief supervisor for initialization and restore.
      init_op: `Operation`.  Used by chief supervisors to initialize the model
        when it cannot be recovered.  Defaults to an `Operation` that
        initializes all global variables.  If `None`, no initialization is done
        automatically unless you pass a value for `init_fn`, see below.
      init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
        This feed dictionary will be used when `init_op` is evaluated.
      local_init_op: `Operation`. Used by all supervisors to run initializations
        that should run for every new supervisor instance. By default these
        are table initializers and initializers for local variables.
        If `None`, no further per supervisor-instance initialization is
        done automatically.
      logdir: A string.  Optional path to a directory where to checkpoint the
        model and log events for the visualizer.  Used by chief supervisors.
        The directory will be created if it does not exist.
      summary_op: An `Operation` that returns a Summary for the event logs.
        Used by chief supervisors if a `logdir` was specified.  Defaults to the
        operation returned from summary.merge_all().  If `None`, summaries are
        not computed automatically.
      saver: A Saver object.  Used by chief supervisors if a `logdir` was
        specified.  Defaults to the saver returned by `Saver()`.
        If `None`, the model is not saved automatically.
      global_step: An integer Tensor of size 1 that counts steps.  The value
        from 'global_step' is used in summaries and checkpoint filenames.
        Defaults to the op named 'global_step' in the graph if it exists, and
        is of rank 1, size 1, and of type tf.int32 or tf.int64.  If `None` the
        global step is not recorded in summaries and checkpoint files.  Used
        by chief supervisors if a `logdir` was specified.
      save_summaries_secs: Number of seconds between the computation of
        summaries for the event log.  Defaults to 120 seconds.  Pass 0 to
        disable summaries.
      save_model_secs: Number of seconds between the creation of model
        checkpoints.  Defaults to 600 seconds.  Pass 0 to disable checkpoints.
      recovery_wait_secs: Number of seconds between checks that the model
        is ready.  Used by supervisors when waiting for a chief supervisor
        to initialize or restore the model.  Defaults to 30 seconds.
      stop_grace_secs: Grace period, in seconds, given to running threads to
        stop when `stop()` is called.  Defaults to 120 seconds.
      checkpoint_basename: The basename for checkpoint saving.
      session_manager: `SessionManager`, which manages Session creation and
        recovery. If it is `None`, a default `SessionManager` will be created
        with the set of arguments passed in for backwards compatibility.
      summary_writer: `SummaryWriter` to use or `USE_DEFAULT`.  Can be `None`
        to indicate that no summaries should be written.
      init_fn: Optional callable used to initialize the model. Called
        after the optional `init_op` is called.  The callable must accept one
        argument, the session being initialized.
      local_init_run_options: RunOptions to be passed as the SessionManager
        local_init_run_options parameter.

    Returns:
      A `Supervisor`.

    Raises:
      RuntimeError: If called with eager execution enabled.

    @compatibility(eager)
    `Supervisor`s are not supported when eager execution is enabled.
    @end_compatibility
    """
    if context.executing_eagerly():
      raise RuntimeError(
          "Supervisors are not compatible with eager execution.")
    # Set default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    with graph.as_default():
      self._init_ready_op(
          ready_op=ready_op, ready_for_local_init_op=ready_for_local_init_op)
      self._init_init_op(init_op=init_op, init_feed_dict=init_feed_dict)
      self._init_local_init_op(local_init_op=local_init_op)
      self._init_saver(saver=saver)
      self._init_summary_op(summary_op=summary_op)
      self._init_global_step(global_step=global_step)
    self._graph = graph
    self._meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph.as_graph_def(add_shapes=True),
        saver_def=self._saver.saver_def if self._saver else None)
    self._is_chief = is_chief
    self._coord = coordinator.Coordinator()
    self._recovery_wait_secs = recovery_wait_secs
    self._stop_grace_secs = stop_grace_secs
    self._init_fn = init_fn
    self._local_init_run_options = local_init_run_options

    # Set all attributes related to checkpointing and writing events to None.
    # Afterwards, set them appropriately for chief supervisors, as these are
    # the only supervisors that can write checkpoints and events.
    self._logdir = None
    self._save_summaries_secs = None
    self._save_model_secs = None
    self._save_path = None
    self._summary_writer = None

    if self._is_chief:
      self._logdir = logdir
      self._save_summaries_secs = save_summaries_secs
      self._save_model_secs = save_model_secs
      if self._logdir:
        self._save_path = os.path.join(self._logdir, checkpoint_basename)
      if summary_writer is Supervisor.USE_DEFAULT:
        if self._logdir:
          self._summary_writer = _summary.FileWriter(self._logdir)
      else:
        self._summary_writer = summary_writer
      self._graph_added_to_summary = False

    self._init_session_manager(session_manager=session_manager)
    self._verify_setup()
    # The graph is not allowed to change anymore.
    graph.finalize()

  def _init_session_manager(self, session_manager=None):
    if session_manager is None:
      self._session_manager = session_manager_mod.SessionManager(
          local_init_op=self._local_init_op,
          ready_op=self._ready_op,
          ready_for_local_init_op=self._ready_for_local_init_op,
          graph=self._graph,
          recovery_wait_secs=self._recovery_wait_secs,
          local_init_run_options=self._local_init_run_options)
    else:
      self._session_manager = session_manager

  def _get_first_op_from_collection(self, key):
    """Returns the first `Operation` from a collection.

    Args:
      key: A string collection key.

    Returns:
      The first Op found in a collection, or `None` if the collection is empty.
    """
    try:
      op_list = ops.get_collection(key)
      if len(op_list) > 1:
        logging.info("Found %d %s operations. Returning the first one.",
                     len(op_list), key)
      if op_list:
        return op_list[0]
    except LookupError:
      pass

    return None

  def _init_ready_op(self,
                     ready_op=USE_DEFAULT,
                     ready_for_local_init_op=USE_DEFAULT):
    """Initializes ready_op.

    Args:
      ready_op: `Tensor` to check if the model is initialized.
        If it's set to USE_DEFAULT, creates an op that checks all
        the variables are initialized.
      ready_for_local_init_op: `Tensor` to check if the model is ready to run
        local_init_op.
        If it's set to USE_DEFAULT, creates an op that checks all
        the global variables are initialized.
    """
    if ready_op is Supervisor.USE_DEFAULT:
      ready_op = self._get_first_op_from_collection(ops.GraphKeys.READY_OP)
      if ready_op is None:
        ready_op = variables.report_uninitialized_variables()
        ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op)
    self._ready_op = ready_op

    # ready_for_local_init_op defaults to None for backward compatibility
    if ready_for_local_init_op is Supervisor.USE_DEFAULT:
      ready_for_local_init_op = self._get_first_op_from_collection(
          ops.GraphKeys.READY_FOR_LOCAL_INIT_OP)
    self._ready_for_local_init_op = ready_for_local_init_op

  def _init_init_op(self, init_op=USE_DEFAULT, init_feed_dict=None):
    """Initializes init_op.

    Args:
      init_op: `Operation` to initialize the variables. If set to USE_DEFAULT,
        create an op that initializes all variables and tables.
      init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
        This feed dictionary will be used when `init_op` is evaluated.
    """
    if init_op is Supervisor.USE_DEFAULT:
      init_op = self._get_first_op_from_collection(ops.GraphKeys.INIT_OP)
      if init_op is None:
        init_op = variables.global_variables_initializer()
        ops.add_to_collection(ops.GraphKeys.INIT_OP, init_op)
    self._init_op = init_op
    self._init_feed_dict = init_feed_dict

  def _init_local_init_op(self, local_init_op=USE_DEFAULT):
    """Initializes local_init_op.

    Args:
      local_init_op: `Operation` run for every new supervisor instance. If set
        to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP
        collection. If the collection is empty, create an op that initializes
        all local variables and all tables.
    """
    if local_init_op is Supervisor.USE_DEFAULT:
      local_init_op = self._get_first_op_from_collection(
          ops.GraphKeys.LOCAL_INIT_OP)
      if local_init_op is None:
        op_list = [
            variables.local_variables_initializer(),
            lookup_ops.tables_initializer()
        ]
        if op_list:
          local_init_op = control_flow_ops.group(*op_list)
          ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
    self._local_init_op = local_init_op

  def _init_saver(self, saver=USE_DEFAULT):
    """Initializes saver.

    Args:
      saver: A `Saver` object. If set to USE_DEFAULT, create one that
        saves all the variables.
    """
    if saver is Supervisor.USE_DEFAULT:
      saver = self._get_first_op_from_collection(ops.GraphKeys.SAVERS)
      if saver is None and variables.global_variables():
        saver = saver_mod.Saver()
        ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
    self._saver = saver

  def _init_summary_op(self, summary_op=USE_DEFAULT):
    """Initializes summary_op.

    Args:
      summary_op: An Operation that returns a Summary for the event logs.
        If set to USE_DEFAULT, create an op that merges all the summaries.
    """
    if summary_op is Supervisor.USE_DEFAULT:
      summary_op = self._get_first_op_from_collection(ops.GraphKeys.SUMMARY_OP)
      if summary_op is None:
        summary_op = _summary.merge_all()
        if summary_op is not None:
          ops.add_to_collection(ops.GraphKeys.SUMMARY_OP, summary_op)
    self._summary_op = summary_op

  def _init_global_step(self, global_step=USE_DEFAULT):
    """Initializes global_step.

    Args:
      global_step: An integer Tensor of size 1 that counts steps. If
        set to USE_DEFAULT, creates a global_step tensor.
    """
    if global_step is Supervisor.USE_DEFAULT:
      global_step = self._get_first_op_from_collection(
          ops.GraphKeys.GLOBAL_STEP)
      if global_step is None:
        global_step = self._default_global_step_tensor()
        if global_step is not None:
          ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, global_step)
    self._global_step = global_step

  @property
  def is_chief(self):
    """Return True if this is a chief supervisor.

    Returns:
      A bool.
    """
    return self._is_chief

  @property
  def session_manager(self):
    """Return the SessionManager used by the Supervisor.

    Returns:
      A SessionManager object.
    """
    return self._session_manager

  @property
  def coord(self):
    """Return the Coordinator used by the Supervisor.

    The Coordinator can be useful if you want to run multiple threads
    during your training.

    Returns:
      A Coordinator object.
    """
    return self._coord

  @property
  def init_op(self):
    """Return the Init Op used by the supervisor.

    Returns:
      An Op or `None`.
    """
    return self._init_op

  @property
  def init_feed_dict(self):
    """Return the feed dictionary used when evaluating the `init_op`.

    Returns:
      A feed dictionary or `None`.
    """
    return self._init_feed_dict

  @property
  def ready_op(self):
    """Return the Ready Op used by the supervisor.

    Returns:
      An Op or `None`.
    """
    return self._ready_op

  @property
  def ready_for_local_init_op(self):
    return self._ready_for_local_init_op

  @property
  def summary_writer(self):
    """Return the SummaryWriter used by the chief supervisor.

    Returns:
      A SummaryWriter.
    """
    return self._summary_writer

  @property
  def summary_op(self):
    """Return the Summary Tensor used by the chief supervisor.

    Returns:
      A string Tensor for the summary or `None`.
    """
    return self._summary_op

  @property
  def save_summaries_secs(self):
    """Return the delay between summary computations.

    Returns:
      A number of seconds.
    """
    return self._save_summaries_secs

  @property
  def global_step(self):
    """Return the global_step Tensor used by the supervisor.

    Returns:
      An integer Tensor for the global_step.
    """
    return self._global_step

  @property
  def saver(self):
    """Return the Saver used by the supervisor.

    Returns:
      A Saver object.
    """
    return self._saver

  @property
  def save_model_secs(self):
    """Return the delay between checkpoints.

    Returns:
      A number of seconds.
    """
    return self._save_model_secs

  @property
  def save_path(self):
    """Return the save path used by the supervisor.

    Returns:
      A string.
    """
    return self._save_path

  def _write_graph(self):
    """Writes graph_def to `logdir` and adds it to summary if applicable."""
    assert self._is_chief
    if self._logdir:
      training_util.write_graph(self._graph.as_graph_def(add_shapes=True),
                                self._logdir, "graph.pbtxt")
    if self._summary_writer and not self._graph_added_to_summary:
      self._summary_writer.add_graph(self._graph)
      self._summary_writer.add_meta_graph(self._meta_graph_def)
      self._graph_added_to_summary = True

  def start_standard_services(self, sess):
    """Start the standard services for 'sess'.

    This starts services in the background.  The services started depend
    on the parameters to the constructor and may include:

      - A Summary thread computing summaries every save_summaries_secs.
      - A Checkpoint thread saving the model every save_model_secs.
      - A StepCounter thread to measure step time.

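    For example, a sketch of starting the services by hand when the session
    was created without them (assuming `sv` is a chief `Supervisor` built
    with a `logdir`):

    ```python
    sess = sv.prepare_or_wait_for_session(
        FLAGS.master, start_standard_services=False)
    threads = sv.start_standard_services(sess)
    ```
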
    Args:
      sess: A Session.

    Returns:
      A list of threads that are running the standard services.  You can use
      the Supervisor's Coordinator to join these threads with:
        sv.coord.join(<list of threads>)

    Raises:
      RuntimeError: If called with a non-chief Supervisor.
      ValueError: If no `logdir` was passed to the constructor, as the
        services need a log directory.
    """
    if not self._is_chief:
      raise RuntimeError("Only the chief supervisor can start standard "
                         "services, because only chief supervisors can "
                         "write events.")

    if not self._logdir:
      logging.warning("Standard services need a 'logdir' "
                      "passed to the Supervisor")
      return

    if self._global_step is not None and self._summary_writer:
      # Only add the session log if we keep track of global step.
      # TensorBoard cannot use START message for purging expired events
      # if there is no step value.
      current_step = training_util.global_step(sess, self._global_step)
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START),
          current_step)

    threads = []
    if self._save_summaries_secs and self._summary_writer:
      if self._summary_op is not None:
        threads.append(SVSummaryThread(self, sess))
      if self._global_step is not None:
        threads.append(SVStepCounterThread(self, sess))
    if self.saver and self._save_model_secs:
      threads.append(SVTimerCheckpointThread(self, sess))
    for t in threads:
      t.start()
    return threads

  def prepare_or_wait_for_session(self, master="", config=None,
                                  wait_for_checkpoint=False,
                                  max_wait_secs=7200,
                                  start_standard_services=True):
    """Make sure the model is ready to be used.

    Create a session on 'master', recovering or initializing the model as
    needed, or wait for a session to be ready.  If running as the chief
    and `start_standard_services` is set to True, also call the session
    manager to start the standard services.

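    For example, a minimal sketch of the non-context-manager workflow
    (`my_train_op` is a placeholder):

    ```python
    sv = Supervisor(logdir='/tmp/mydir')
    sess = sv.prepare_or_wait_for_session('')
    while not sv.should_stop():
      sess.run(my_train_op)
    sv.stop()
    ```
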
    Args:
      master: name of the TensorFlow master to use.  See the `tf.Session`
        constructor for how this is interpreted.
      config: Optional ConfigProto proto used to configure the session,
        which is passed as-is to create the session.
      wait_for_checkpoint: Whether we should wait for the availability of a
        checkpoint before creating Session. Defaults to False.
      max_wait_secs: Maximum time to wait for the session to become available.
      start_standard_services: Whether to start the standard services and the
        queue runners.

    Returns:
      A Session object that can be used to drive the model.
    """
    # For users who recreate the session with prepare_or_wait_for_session(), we
    # need to clear the coordinator's stop_event so that threads managed by the
    # coordinator can run.
    self._coord.clear_stop()
    if self._summary_writer:
      self._summary_writer.reopen()

    if self._is_chief:
      sess = self._session_manager.prepare_session(
          master, init_op=self.init_op, saver=self.saver,
          checkpoint_dir=self._logdir, wait_for_checkpoint=wait_for_checkpoint,
          max_wait_secs=max_wait_secs, config=config,
          init_feed_dict=self._init_feed_dict, init_fn=self._init_fn)
      self._write_graph()
      if start_standard_services:
        logging.info("Starting standard services.")
        self.start_standard_services(sess)
    else:
      sess = self._session_manager.wait_for_session(master,
                                                    config=config,
                                                    max_wait_secs=max_wait_secs)
    if start_standard_services:
      logging.info("Starting queue runners.")
      self.start_queue_runners(sess)
    return sess

  def start_queue_runners(self, sess, queue_runners=None):
    """Start threads for `QueueRunners`.

    Note that the queue runners collected in the graph key `QUEUE_RUNNERS`
    are already started automatically when you create a session with the
    supervisor, so unless you have non-collected queue runners to start,
    you do not need to call this explicitly.

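    For example, a sketch of starting a queue runner that was deliberately
    kept out of the `QUEUE_RUNNERS` collection (`my_queue_runner` is a
    placeholder):

    ```python
    extra_threads = sv.start_queue_runners(
        sess, queue_runners=[my_queue_runner])
    ```
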
    Args:
      sess: A `Session`.
      queue_runners: A list of `QueueRunners`. If not specified, we'll use the
        list of queue runners gathered in the graph under the key
        `GraphKeys.QUEUE_RUNNERS`.

    Returns:
      The list of threads started for the `QueueRunners`.

    Raises:
      RuntimeError: If called with eager execution enabled.

    @compatibility(eager)
    Queues are not compatible with eager execution. To ingest data when eager
    execution is enabled, use the `tf.data` API.
    @end_compatibility
    """
    if context.executing_eagerly():
      raise RuntimeError("Queues are not compatible with eager execution.")
    if queue_runners is None:
      queue_runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
    threads = []
    for qr in queue_runners:
      threads.extend(qr.create_threads(sess, coord=self._coord, daemon=True,
                                       start=True))
    return threads

  def loop(self, timer_interval_secs, target, args=None, kwargs=None):
    """Start a LooperThread that calls a function periodically.

    If `timer_interval_secs` is None, the thread calls `target(*args, **kwargs)`
    repeatedly.  Otherwise it calls it every `timer_interval_secs`
    seconds.  The thread terminates when a stop is requested.

    The started thread is added to the list of threads managed by the
    supervisor, so it does not need to be passed to the `stop()` method.

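    For example, a sketch that logs the loss every 60 seconds (`print_loss`
    and `sess` are placeholders):

    ```python
    looper = sv.loop(60, print_loss, args=(sess,))
    ```
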
    Args:
      timer_interval_secs: Number. Time boundaries at which to call `target`.
      target: A callable object.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Returns:
      The started thread.
    """
    looper = coordinator.LooperThread(self._coord, timer_interval_secs,
                                      target=target, args=args, kwargs=kwargs)
    looper.start()
    return looper

  def stop(self,
           threads=None,
           close_summary_writer=True,
           ignore_live_threads=False):
    """Stop the services and the coordinator.

    This does not close the session.

    Args:
      threads: Optional list of threads to join with the coordinator.  If
        `None`, defaults to the threads running the standard services, the
        threads started for `QueueRunners`, and the threads started by the
        `loop()` method.  To wait on additional threads, pass the
        list in this parameter.
      close_summary_writer: Whether to close the `summary_writer`.  Defaults to
        `True` if the summary writer was created by the supervisor, `False`
        otherwise.
      ignore_live_threads: If `True`, ignores threads that remain running after
        a grace period when joining threads via the coordinator, instead of
        raising a RuntimeError.
    """
    self._coord.request_stop()
    try:
      # coord.join() re-raises the first reported exception; the "finally"
      # block ensures that we clean up whether or not an exception was
      # reported.
      self._coord.join(
          threads,
          stop_grace_period_secs=self._stop_grace_secs,
          ignore_live_threads=ignore_live_threads)
    finally:
      # Close the writer last, in case one of the running threads was using it.
      if close_summary_writer and self._summary_writer:
        # Stop messages are not logged with event.step,
        # since the session may have already terminated.
        self._summary_writer.add_session_log(SessionLog(status=SessionLog.STOP))
        self._summary_writer.close()
        self._graph_added_to_summary = False

  def request_stop(self, ex=None):
    """Request that the coordinator stop the threads.

    See `Coordinator.request_stop()`.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.
    """
    self._coord.request_stop(ex=ex)

  def should_stop(self):
    """Check if the coordinator was told to stop.

    See `Coordinator.should_stop()`.

    Returns:
      True if the coordinator was told to stop, False otherwise.
    """
    return self._coord.should_stop()

  def stop_on_exception(self):
    """Context handler to stop the supervisor when an exception is raised.

    See `Coordinator.stop_on_exception()`.

    Returns:
      A context handler.
    """
    return self._coord.stop_on_exception()

  def wait_for_stop(self):
    """Block waiting for the coordinator to stop."""
    self._coord.wait_for_stop()

  def summary_computed(self, sess, summary, global_step=None):
    """Indicate that a summary was computed.

    Args:
      sess: A `Session` object.
      summary: A Summary proto, or a string holding a serialized summary proto.
      global_step: Int. The global step this summary is associated with. If
        `None`, it will try to fetch the current step.

    Raises:
      TypeError: if 'summary' is not a Summary proto or a string.
      RuntimeError: if the Supervisor was created without a `logdir`.
    """
    if not self._summary_writer:
      raise RuntimeError("Writing a summary requires a summary writer.")
    if global_step is None and self.global_step is not None:
      global_step = training_util.global_step(sess, self.global_step)
    self._summary_writer.add_summary(summary, global_step)

  def _default_global_step_tensor(self):
    """Returns the global_step from the default graph.

    Returns:
      The global step `Tensor` or `None`.
    """
    try:
      gs = ops.get_default_graph().get_tensor_by_name("global_step:0")
      if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]:
        return gs
      else:
        logging.warning("Found 'global_step' but it is not an int type: %s",
                        gs.dtype)
        return None
    except KeyError:
      return None

  def _verify_setup(self):
    """Check that all is good.

    Raises:
      ValueError: If something is not good.
    """
    # Not running as chief means that replicas are used.
    # In that case all Variables must have their device set.
    if not self._is_chief:
      for op in self._graph.get_operations():
        if op.type in ["Variable", "VariableV2"] and not op.device:
          raise ValueError("When using replicas, all Variables must have "
                           "their device set: %s" % op)

  # pylint: disable=g-doc-return-or-yield,broad-except
  @contextlib.contextmanager
  def managed_session(self, master="", config=None,
                      start_standard_services=True,
                      close_summary_writer=True):
    """Returns a context manager for a managed session.

    This context manager creates and automatically recovers a session.  It
    optionally starts the standard services that handle checkpoints and
    summaries.  It monitors exceptions raised from the `with` block or from the
    services and stops the supervisor as needed.

    The context manager is typically used as follows:

    ```python
    def train():
      sv = tf.train.Supervisor(...)
      with sv.managed_session(<master>) as sess:
        for step in range(..):
          if sv.should_stop():
            break
          sess.run(<my training op>)
          ...do other things needed at each training step...
    ```

    An exception raised from the `with` block or one of the service threads is
    raised again when the block exits.  This is done after stopping all threads
    and closing the session.  For example, an `AbortedError` exception, raised
    in case of preemption of one of the workers in a distributed model, is
    raised again when the block exits.

    If you want to retry the training loop in case of preemption, you can do it
    as follows:

    ```python
    def main(...):
      while True:
        try:
          train()
        except tf.errors.AbortedError:
          pass
    ```

    As a special case, exceptions used for control flow, such as
    `OutOfRangeError` which reports that input queues are exhausted, are not
    raised again from the `with` block: they indicate a clean termination of
    the training loop and are considered normal termination.

    Args:
      master: name of the TensorFlow master to use.  See the `tf.Session`
        constructor for how this is interpreted.
      config: Optional `ConfigProto` proto used to configure the session.
        Passed as-is to create the session.
      start_standard_services: Whether to start the standard services,
        such as checkpoint, summary and step counter.
      close_summary_writer: Whether to close the summary writer when
        closing the session.  Defaults to True.

    Returns:
      A context manager that yields a `Session` restored from the latest
      checkpoint or initialized from scratch if no checkpoint exists.  The
      session is closed when the `with` block exits.
    """
    try:
      sess = self.prepare_or_wait_for_session(
          master=master, config=config,
          start_standard_services=start_standard_services)
      yield sess
    except Exception as e:
      self.request_stop(e)
    finally:
      try:
        # Request all the threads to stop and wait for them to do so.  Any
        # exception raised by the threads is raised again from stop().
        # Passing stop_grace_period_secs is for blocked enqueue/dequeue
        # threads which are not checking for `should_stop()`.  They
        # will be stopped when we close the session further down.
        self.stop(close_summary_writer=close_summary_writer)
      finally:
        # Close the session to finish up all pending calls.  We do not care
        # about exceptions raised when closing.  This takes care of
        # blocked enqueue/dequeue calls.
        try:
          sess.close()
        except Exception:
          # Silently ignore exceptions raised by close().
          pass
  # pylint: enable=g-doc-return-or-yield,broad-except


class SVSummaryThread(coordinator.LooperThread):
  """A thread to save summaries on a timer."""

  def __init__(self, sv, sess):
    """Create a SVSummaryThread.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
    """
    super(SVSummaryThread, self).__init__(sv.coord, sv.save_summaries_secs)
    self._sv = sv
    self._sess = sess

  def run_loop(self):
    if self._sv.global_step is not None:
      summary_strs, global_step = self._sess.run([self._sv.summary_op,
                                                  self._sv.global_step])
    else:
      summary_strs = self._sess.run(self._sv.summary_op)
      global_step = None
    if self._sv.summary_writer:
      logging.info("Recording summary at step %s.", global_step)
      self._sv.summary_writer.add_summary(summary_strs, global_step)


class SVStepCounterThread(coordinator.LooperThread):
  """A thread to count steps and measure their duration."""

  def __init__(self, sv, sess, step_counter=None):
    """Create a `SVStepCounterThread`.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
      step_counter: A `Tensor` holding the step counter. By default, it uses
        `sv.global_step`.
    """
    super(SVStepCounterThread, self).__init__(sv.coord, sv.save_summaries_secs)
    self._sv = sv
    self._sess = sess
    self._last_time = 0.0
    self._last_step = 0
    step_counter = sv.global_step if step_counter is None else step_counter
    self._step_counter = step_counter
    self._summary_tag = "%s/sec" % self._step_counter.op.name

  def start_loop(self):
    self._last_time = time.time()
    self._last_step = training_util.global_step(
        self._sess, self._step_counter)

  def run_loop(self):
    # Count the steps.
    current_step = training_util.global_step(self._sess, self._step_counter)
    added_steps = current_step - self._last_step
    self._last_step = current_step
    # Measure the elapsed time.
    current_time = time.time()
    elapsed_time = current_time - self._last_time
    self._last_time = current_time
    # Report the number of steps done per second.
    if elapsed_time > 0.:
      steps_per_sec = added_steps / elapsed_time
    else:
      steps_per_sec = float("inf")
    summary = Summary(value=[Summary.Value(tag=self._summary_tag,
                                           simple_value=steps_per_sec)])
    if self._sv.summary_writer:
      self._sv.summary_writer.add_summary(summary, current_step)
    logging.log_first_n(logging.INFO, "%s: %g", 10,
                        self._summary_tag, steps_per_sec)


class SVTimerCheckpointThread(coordinator.LooperThread):
  """A thread to checkpoint on a timer."""

  def __init__(self, sv, sess):
    """Create a `SVTimerCheckpointThread`.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
    """
    super(SVTimerCheckpointThread, self).__init__(sv.coord, sv.save_model_secs)
    self._sv = sv
    self._sess = sess

  def run_loop(self):
    logging.info("Saving checkpoint to path %s", self._sv.save_path)
    self._sv.saver.save(self._sess, self._sv.save_path,
                        global_step=self._sv.global_step)
    if self._sv.summary_writer and self._sv.global_step is not None:
      current_step = training_util.global_step(self._sess, self._sv.global_step)
      self._sv.summary_writer.add_session_log(
          SessionLog(status=SessionLog.CHECKPOINT,
                     checkpoint_path=self._sv.save_path),
          current_step)


# TODO(sherrym): All non-PEP8 compliant names will be deprecated shortly.
setattr(Supervisor, "PrepareSession", Supervisor.prepare_or_wait_for_session)
setattr(Supervisor, "StartQueueRunners", Supervisor.start_queue_runners)
setattr(Supervisor, "StartStandardServices", Supervisor.start_standard_services)
setattr(Supervisor, "Stop", Supervisor.stop)
setattr(Supervisor, "RequestStop", Supervisor.request_stop)
setattr(Supervisor, "Loop", Supervisor.loop)
setattr(Supervisor, "ShouldStop", Supervisor.should_stop)
setattr(Supervisor, "StopOnException", Supervisor.stop_on_exception)
setattr(Supervisor, "WaitForStop", Supervisor.wait_for_stop)
setattr(Supervisor, "SummaryComputed", Supervisor.summary_computed)