• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training helper that checkpoints models and creates session."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import time
21
22import numpy as np
23
24from tensorflow.python.client import session
25from tensorflow.python.distribute import distribution_strategy_context
26from tensorflow.python.framework import errors
27from tensorflow.python.framework import ops
28from tensorflow.python.platform import tf_logging as logging
29from tensorflow.python.training import checkpoint_management
30from tensorflow.python.util.tf_export import tf_export
31
32
def _maybe_name(obj):
  """Describes `obj` by name, for use in error messages.

  Args:
    obj: Object to describe; may be `None`.

  Returns:
    The string "None" when `obj` is `None`, `obj.name` when `obj` has a
    `name` attribute, and otherwise a "<no name for TYPE>" placeholder.
  """
  if obj is None:
    return "None"
  if not hasattr(obj, "name"):
    return "<no name for %s>" % type(obj)
  return obj.name
48
49
def _restore_checkpoint_and_maybe_run_saved_model_initializers(
    sess, saver, path):
  """Runs any SavedModel initializers, then restores `path` with `saver`.

  "SavedModel" here means a SavedModel loaded through the `load_v2` API (which
  takes no `sess` argument). Resources coming from such a SavedModel are not
  restored by `saver.restore`, so their initializers, kept in the
  "saved_model_initializers" graph collection, are run here first. That
  collection is part of the MetaGraph's default_init_op, so MonitoredSession
  already runs it whenever the saver does not restore a checkpoint from the
  working dir; this helper covers the case where it does.

  Args:
    sess: The `Session` to initialize and restore into.
    saver: A `Saver` whose `restore` loads the checkpoint.
    path: Path of the checkpoint to restore.
  """
  initializers = ops.get_collection("saved_model_initializers")
  if initializers:
    sess.run(initializers)
  # Order matters: the SavedModel initializers restore variables from the
  # SavedModel's own variables directory, so the checkpoint restore has to run
  # afterwards. Initializing/restoring twice is wasteful but unavoidable.
  saver.restore(sess, path)
72
73
@tf_export(v1=["train.SessionManager"])
class SessionManager(object):
  """Training helper that restores from checkpoint and creates session.

  This class is a small wrapper that takes care of session creation and
  checkpoint recovery. It also provides functions to facilitate coordination
  among multiple training threads or processes:

  * Initializing variables on startup, or restoring them from the most recent
    checkpoint after a crash.
  * Waiting for checkpoints, or for a model initialized by another process,
    to become available.

  ### Usage:

  ```python
  with tf.Graph().as_default():
    ...add operations to the graph...
    # Create a SessionManager that restores the model from `checkpoint_dir`
    # when a checkpoint is available, and runs `init_op` otherwise.
    sm = SessionManager()
    sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
  ```

  `prepare_session()` initializes or restores a model. It is usually given an
  `init_op` and a `saver`.

  A second process could wait for the model to be ready by doing the following:

  ```python
  with tf.Graph().as_default():
    ...add operations to the graph...
    # Create a SessionManager that will wait for the model to become ready.
    sm = SessionManager()
    sess = sm.wait_for_session(master)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
  ```

  `wait_for_session()` waits for a model to be initialized by other processes.
  """

  def __init__(self,
               local_init_op=None,
               ready_op=None,
               ready_for_local_init_op=None,
               graph=None,
               recovery_wait_secs=30,
               local_init_run_options=None,
               local_init_feed_dict=None):
    """Creates a SessionManager.

    The `local_init_op` is an `Operation` that is always run after a new
    session is created. If `None`, this step is skipped.

    The `ready_op` is an `Operation` used to check if the model is ready.  The
    model is considered ready if that operation returns an empty 1D string
    tensor. If the operation returns a non empty 1D string tensor, the elements
    are concatenated and used to indicate to the user why the model is not
    ready.

    The `ready_for_local_init_op` is an `Operation` used to check if the model
    is ready to run local_init_op.  The model is considered ready if that
    operation returns an empty 1D string tensor. If the operation returns a non
    empty 1D string tensor, the elements are concatenated and used to indicate
    to the user why the model is not ready.

    If `ready_op` is `None`, the model is not checked for readiness.

    `recovery_wait_secs` is the number of seconds between checks that
    the model is ready.  It is used by processes to wait for a model to
    be initialized or restored.  Defaults to 30 seconds.

    Args:
      local_init_op: An `Operation` run immediately after session creation.
         Usually used to initialize tables and local variables.
      ready_op: An `Operation` to check if the model is initialized.
      ready_for_local_init_op: An `Operation` to check if the model is ready
         to run local_init_op.
      graph: The `Graph` that the model will use. Defaults to the current
         default graph.
      recovery_wait_secs: Seconds between checks for the model to be ready.
      local_init_run_options: RunOptions to be passed to session.run when
        executing the local_init_op.
      local_init_feed_dict: Optional session feed dictionary to use when running
        the local_init_op.

    Raises:
      ValueError: If ready_for_local_init_op is not None but local_init_op is
        None
    """
    # Sets default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    self._local_init_op = local_init_op
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._graph = graph
    self._recovery_wait_secs = recovery_wait_secs
    self._target = None
    self._local_init_run_options = local_init_run_options
    self._local_init_feed_dict = local_init_feed_dict
    if ready_for_local_init_op is not None and local_init_op is None:
      raise ValueError("If you pass a ready_for_local_init_op "
                       "you must also pass a local_init_op "
                       ", ready_for_local_init_op [%s]" %
                       ready_for_local_init_op)

  def _restore_checkpoint(self,
                          master,
                          saver=None,
                          checkpoint_dir=None,
                          checkpoint_filename_with_path=None,
                          wait_for_checkpoint=False,
                          max_wait_secs=7200,
                          config=None):
    """Creates a `Session`, and tries to restore a checkpoint.

    If restoring raises, or waiting for a checkpoint is interrupted, the newly
    created session is closed before the exception propagates.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, is_restored) where 'is_restored' is `True` if
      the session could be restored, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
    self._target = master

    # Validate arguments before initializing devices or creating a session, so
    # an invalid call does not leave an unreferenced session behind.
    if checkpoint_dir and checkpoint_filename_with_path:
      raise ValueError("Can not provide both checkpoint_dir and "
                       "checkpoint_filename_with_path.")

    # This is required so that we initialize the TPU device before
    # restoring from checkpoint since we'll be placing variables on the device
    # and TPUInitialize wipes out the memory of the device.
    strategy = distribution_strategy_context.get_strategy()
    if strategy and hasattr(strategy.extended,
                            "_experimental_initialize_system"):
      strategy.extended._experimental_initialize_system()  # pylint: disable=protected-access

    sess = session.Session(self._target, graph=self._graph, config=config)
    try:
      # If either saver or checkpoint_* is not specified, cannot restore. Just
      # return.
      if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
        return sess, False

      if checkpoint_filename_with_path:
        _restore_checkpoint_and_maybe_run_saved_model_initializers(
            sess, saver, checkpoint_filename_with_path)
        return sess, True

      # Waits up until max_wait_secs for checkpoint to become available.
      wait_time = 0
      ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
      while not ckpt or not ckpt.model_checkpoint_path:
        if wait_for_checkpoint and wait_time < max_wait_secs:
          logging.info("Waiting for checkpoint to be available.")
          time.sleep(self._recovery_wait_secs)
          wait_time += self._recovery_wait_secs
          ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
        else:
          return sess, False

      # Loads the checkpoint.
      _restore_checkpoint_and_maybe_run_saved_model_initializers(
          sess, saver, ckpt.model_checkpoint_path)
      saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
      return sess, True
    except BaseException:
      # The caller never receives `sess` when restoring fails, so close it
      # here instead of leaking it; the original exception is re-raised.
      self._safe_close(sess)
      raise

  def prepare_session(self,
                      master,
                      init_op=None,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None,
                      init_feed_dict=None,
                      init_fn=None):
    """Creates a `Session`. Makes sure the model is ready to be used.

    Creates a `Session` on 'master'. If a `saver` object is passed in, and
    `checkpoint_dir` points to a directory containing valid checkpoint
    files, then it will try to recover the model from checkpoint. If
    no checkpoint files are available, and `wait_for_checkpoint` is
    `True`, then the process would check every `recovery_wait_secs`,
    up to `max_wait_secs`, for recovery to succeed.

    If the model cannot be recovered successfully then it is initialized by
    running the `init_op` and calling `init_fn` if they are provided.
    The `local_init_op` is also run after init_op and init_fn, regardless of
    whether the model was recovered successfully, but only if
    `ready_for_local_init_op` passes.

    If the model is recovered from a checkpoint it is assumed that all
    global variables have been initialized, in particular neither `init_op`
    nor `init_fn` will be executed.

    It is an error if the model cannot be recovered and no `init_op`
    or `init_fn` or `local_init_op` are passed.

    If any step raises, the session created by this call is closed before the
    exception propagates.

    Args:
      master: `String` representation of the TensorFlow master to use.
      init_op: Optional `Operation` used to initialize the model.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.
      init_feed_dict: Optional dictionary that maps `Tensor` objects to feed
        values.  This feed dictionary is passed to the session `run()` call when
        running the init op.
      init_fn: Optional callable used to initialize the model. Called after the
        optional `init_op` is called.  The callable must accept one argument,
        the session being initialized.

    Returns:
      A `Session` object that can be used to drive the model.

    Raises:
      RuntimeError: If the model cannot be initialized or recovered.
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """

    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)
    try:
      if not is_loaded_from_checkpoint:
        if init_op is None and not init_fn and self._local_init_op is None:
          raise RuntimeError("Model is not initialized and no init_op or "
                             "init_fn or local_init_op was given")
        if init_op is not None:
          sess.run(init_op, feed_dict=init_feed_dict)
        if init_fn:
          init_fn(sess)

      local_init_success, msg = self._try_run_local_init_op(sess)
      if not local_init_success:
        raise RuntimeError(
            "Init operations did not make model ready for local_init.  "
            "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op),
                                                     init_fn,
                                                     msg))

      is_ready, msg = self._model_ready(sess)
      if not is_ready:
        raise RuntimeError(
            "Init operations did not make model ready.  "
            "Init op: %s, init fn: %s, local_init_op: %s, error: %s" %
            (_maybe_name(init_op), init_fn, self._local_init_op, msg))
    except BaseException:
      # `sess` is only handed back on success; close it rather than leak it.
      self._safe_close(sess)
      raise
    return sess

  def recover_session(self,
                      master,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None):
    """Creates a `Session`, recovering if possible.

    Creates a new session on 'master'.  If the session is not initialized
    and can be recovered from a checkpoint, recover it. The session is
    returned even when recovery fails; it is only closed if an exception
    propagates out of this call.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, initialized) where 'initialized' is `True` if
      the session could be recovered and initialized, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """

    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)

    try:
      # Always try to run local_init_op
      local_init_success, msg = self._try_run_local_init_op(sess)

      if not is_loaded_from_checkpoint:
        # Do not need to run checks for readiness
        return sess, False

      restoring_file = checkpoint_dir or checkpoint_filename_with_path
      if not local_init_success:
        logging.info(
            "Restoring model from %s did not make model ready for local init:"
            " %s", restoring_file, msg)
        return sess, False

      is_ready, msg = self._model_ready(sess)
      if not is_ready:
        logging.info("Restoring model from %s did not make model ready: %s",
                     restoring_file, msg)
        return sess, False
    except BaseException:
      # The caller never receives `sess` if a check raises; don't leak it.
      self._safe_close(sess)
      raise

    logging.info("Restored model from %s", restoring_file)
    return sess, True

  def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    """Creates a new `Session` and waits for model to be ready.

    Creates a new `Session` on 'master'.  Waits for the model to be
    initialized or recovered from a checkpoint.  It's expected that
    another thread or process will make the model ready, and that this
    is intended to be used by threads/processes that participate in a
    distributed training configuration where a different thread/process
    is responsible for initializing or recovering the model being trained.

    NB: The amount of time this method waits for the session is bounded
    by max_wait_secs. By default, this function will wait indefinitely.

    Args:
      master: `String` representation of the TensorFlow master to use.
      config: Optional ConfigProto proto used to configure the session.
      max_wait_secs: Maximum time to wait for the session to become available.
        `None` is treated as no limit.

    Returns:
      A `Session` whose model passed the readiness checks.

    Raises:
      tf.DeadlineExceededError: if the session is not available after
        max_wait_secs.
    """
    self._target = master

    if max_wait_secs is None:
      max_wait_secs = float("Inf")
    timer = _CountDownTimer(max_wait_secs)

    while True:
      sess = session.Session(self._target, graph=self._graph, config=config)
      not_ready_msg = None
      try:
        local_init_success, not_ready_local_msg = self._try_run_local_init_op(
            sess)
        if local_init_success:
          # Successful if local_init_op is None, or ready_for_local_init_op
          # passes.
          is_ready, not_ready_msg = self._model_ready(sess)
          if is_ready:
            return sess
      except BaseException:
        self._safe_close(sess)
        raise

      self._safe_close(sess)

      # Do we have enough time left to try again? An exhausted budget always
      # stops the loop, so a recovery_wait_secs of 0 cannot turn the timeout
      # into an endless busy-wait.
      remaining_secs = timer.secs_remaining()
      if remaining_secs <= 0 or remaining_secs < self._recovery_wait_secs:
        raise errors.DeadlineExceededError(
            None, None,
            "Session was not ready after waiting %d secs." % (max_wait_secs,))

      logging.info("Waiting for model to be ready.  "
                   "Ready_for_local_init_op:  %s, ready: %s",
                   not_ready_local_msg, not_ready_msg)
      time.sleep(self._recovery_wait_secs)

  def _safe_close(self, sess):
    """Closes a session without raising an exception.

    Just like sess.close() but ignores exceptions.

    Args:
      sess: A `Session`.
    """
    # pylint: disable=broad-except
    try:
      sess.close()
    except Exception:
      # Intentionally not logging to avoid user complaints that
      # they get cryptic errors.  We really do not care that Close
      # fails.
      pass
    # pylint: enable=broad-except

  def _model_ready(self, sess):
    """Checks if the model is ready or not.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_ready, msg), where is_ready is True if ready and False
      otherwise, and msg is `None` if the model is ready, a `String` with the
      reason why it is not ready otherwise.
    """
    return _ready(self._ready_op, sess, "Model not ready")

  def _model_ready_for_local_init(self, sess):
    """Checks if the model is ready to run local_init_op.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_ready, msg), where is_ready is True if ready to run
      local_init_op and False otherwise, and msg is `None` if the model is
      ready to run local_init_op, a `String` with the reason why it is not ready
      otherwise.
    """
    return _ready(self._ready_for_local_init_op, sess,
                  "Model not ready for local init")

  def _try_run_local_init_op(self, sess):
    """Tries to run _local_init_op, if not None, and is ready for local init.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_successful, msg), where is_successful is True if
      _local_init_op is None, or we ran _local_init_op, and False otherwise;
      and msg is `None` on success, otherwise a `String` with the reason why
      the model was not ready to run local init.
    """
    if self._local_init_op is not None:
      is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
      if is_ready_for_local_init:
        logging.info("Running local_init_op.")
        sess.run(self._local_init_op, feed_dict=self._local_init_feed_dict,
                 options=self._local_init_run_options)
        logging.info("Done running local_init_op.")
        return True, None
      else:
        return False, msg
    return True, None
539
540
def _ready(op, sess, msg):
  """Checks if the model is ready or not, as determined by op.

  Args:
    op: An op, either _ready_op or _ready_for_local_init_op, which defines the
      readiness of the model. `None` means the model is always ready.
    sess: A `Session`.
    msg: A message to log to warning if not ready

  Returns:
    A tuple (is_ready, msg), where is_ready is True if ready and False
    otherwise, and msg is `None` if the model is ready, a `String` with the
    reason why it is not ready otherwise.

  Raises:
    FailedPreconditionError: re-raised from `sess.run(op)` when its message
      does not mention "uninitialized".
  """
  if op is None:
    return True, None

  try:
    ready_value = sess.run(op)
  except errors.FailedPreconditionError as e:
    # An "uninitialized" failure only means "not ready yet"; anything else is
    # a real error and is propagated with its original traceback.
    if "uninitialized" not in str(e):
      logging.warning("%s : error [%s]", msg, str(e))
      raise
    return False, str(e)

  # The model is considered ready if ready_op returns an empty 1-D tensor.
  # Also compare to `None` and dtype being int32 for backward
  # compatibility.
  if ready_value is None:
    return True, None
  # Scalar results (e.g. a 0-D string tensor fetched as plain `bytes`) have no
  # `.dtype`/`.size` and cannot be iterated; normalize them to an array.
  ready_value = np.asarray(ready_value)
  if ready_value.dtype == np.int32 or ready_value.size == 0:
    return True, None

  # TODO(sherrym): If a custom ready_op returns other types of tensor,
  # or strings other than variable names, this message could be
  # confusing.
  non_initialized_varnames = ", ".join(
      name.decode("utf-8") if isinstance(name, bytes) else str(name)
      for name in np.ravel(ready_value))
  return False, "Variables not initialized: " + non_initialized_varnames
578
579
class _CountDownTimer(object):
  """Tracks how much of a fixed time budget is left.

  The countdown starts when the timer is constructed. `duration_secs` may be
  `float("inf")` for an unbounded budget.
  """

  __slots__ = ["_start_time_secs", "_duration_secs"]

  def __init__(self, duration_secs):
    self._duration_secs = duration_secs
    self._start_time_secs = time.time()

  def secs_remaining(self):
    """Returns the seconds left in the budget, clamped to be at least 0."""
    elapsed = time.time() - self._start_time_secs
    remaining = self._duration_secs - elapsed
    return remaining if remaining > 0 else 0
591