• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training helper that checkpoints models and creates session."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import time
21import numpy as np
22
23from tensorflow.python.client import session
24from tensorflow.python.framework import errors
25from tensorflow.python.framework import ops
26from tensorflow.python.platform import tf_logging as logging
27from tensorflow.python.training import saver as saver_mod
28from tensorflow.python.util.tf_export import tf_export
29
30
31def _maybe_name(obj):
32  """Returns object name if it has one, or a message otherwise.
33
34  This is useful for names that apper in error messages.
35  Args:
36    obj: Object to get the name of.
37  Returns:
38    name, "None", or a "no name" message.
39  """
40  if obj is None:
41    return "None"
42  elif hasattr(obj, "name"):
43    return obj.name
44  else:
45    return "<no name for %s>" % type(obj)
46
47
48@tf_export("train.SessionManager")
49class SessionManager(object):
50  """Training helper that restores from checkpoint and creates session.
51
52  This class is a small wrapper that takes care of session creation and
  checkpoint recovery. It also provides functions to facilitate
  coordination among multiple training threads or processes.
55
56  * Checkpointing trained variables as the training progresses.
  * Initializing variables on startup, restoring them from the most recent
    checkpoint after a crash, or waiting for checkpoints to become available.
59
60  ### Usage:
61
62  ```python
63  with tf.Graph().as_default():
64     ...add operations to the graph...
65    # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
66    sm = SessionManager()
67    sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
68    # Use the session to train the graph.
69    while True:
70      sess.run(<my_train_op>)
71  ```
72
  `prepare_session()` initializes or restores a model. It takes `init_op`
  and `saver` as arguments.
75
76  A second process could wait for the model to be ready by doing the following:
77
78  ```python
79  with tf.Graph().as_default():
80     ...add operations to the graph...
81    # Create a SessionManager that will wait for the model to become ready.
82    sm = SessionManager()
83    sess = sm.wait_for_session(master)
84    # Use the session to train the graph.
85    while True:
86      sess.run(<my_train_op>)
87  ```
88
89  `wait_for_session()` waits for a model to be initialized by other processes.
90
91  """
92
93  def __init__(self,
94               local_init_op=None,
95               ready_op=None,
96               ready_for_local_init_op=None,
97               graph=None,
98               recovery_wait_secs=30):
99    """Creates a SessionManager.
100
101    The `local_init_op` is an `Operation` that is run always after a new session
102    was created. If `None`, this step is skipped.
103
104    The `ready_op` is an `Operation` used to check if the model is ready.  The
105    model is considered ready if that operation returns an empty 1D string
106    tensor. If the operation returns a non empty 1D string tensor, the elements
107    are concatenated and used to indicate to the user why the model is not
108    ready.
109
110    The `ready_for_local_init_op` is an `Operation` used to check if the model
111    is ready to run local_init_op.  The model is considered ready if that
112    operation returns an empty 1D string tensor. If the operation returns a non
113    empty 1D string tensor, the elements are concatenated and used to indicate
114    to the user why the model is not ready.
115
116    If `ready_op` is `None`, the model is not checked for readiness.
117
118    `recovery_wait_secs` is the number of seconds between checks that
119    the model is ready.  It is used by processes to wait for a model to
120    be initialized or restored.  Defaults to 30 seconds.
121
122    Args:
123      local_init_op: An `Operation` run immediately after session creation.
124         Usually used to initialize tables and local variables.
125      ready_op: An `Operation` to check if the model is initialized.
126      ready_for_local_init_op: An `Operation` to check if the model is ready
127         to run local_init_op.
128      graph: The `Graph` that the model will use.
129      recovery_wait_secs: Seconds between checks for the model to be ready.
130
131    Raises:
132      ValueError: If ready_for_local_init_op is not None but local_init_op is
133        None
134    """
135    # Sets default values of arguments.
136    if graph is None:
137      graph = ops.get_default_graph()
138    self._local_init_op = local_init_op
139    self._ready_op = ready_op
140    self._ready_for_local_init_op = ready_for_local_init_op
141    self._graph = graph
142    self._recovery_wait_secs = recovery_wait_secs
143    self._target = None
144    if ready_for_local_init_op is not None and local_init_op is None:
145      raise ValueError("If you pass a ready_for_local_init_op "
146                       "you must also pass a local_init_op "
147                       ", ready_for_local_init_op [%s]" %
148                       ready_for_local_init_op)
149
150  def _restore_checkpoint(self,
151                          master,
152                          saver=None,
153                          checkpoint_dir=None,
154                          checkpoint_filename_with_path=None,
155                          wait_for_checkpoint=False,
156                          max_wait_secs=7200,
157                          config=None):
158    """Creates a `Session`, and tries to restore a checkpoint.
159
160
161    Args:
162      master: `String` representation of the TensorFlow master to use.
163      saver: A `Saver` object used to restore a model.
164      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
165        dir will be used to restore.
166      checkpoint_filename_with_path: Full file name path to the checkpoint file.
167      wait_for_checkpoint: Whether to wait for checkpoint to become available.
168      max_wait_secs: Maximum time to wait for checkpoints to become available.
169      config: Optional `ConfigProto` proto used to configure the session.
170
171    Returns:
172      A pair (sess, is_restored) where 'is_restored' is `True` if
173      the session could be restored, `False` otherwise.
174
175    Raises:
176      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
177        set.
178    """
179    self._target = master
180    sess = session.Session(self._target, graph=self._graph, config=config)
181
182    if checkpoint_dir and checkpoint_filename_with_path:
183      raise ValueError("Can not provide both checkpoint_dir and "
184                       "checkpoint_filename_with_path.")
185    # If either saver or checkpoint_* is not specified, cannot restore. Just
186    # return.
187    if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
188      return sess, False
189
190    if checkpoint_filename_with_path:
191      saver.restore(sess, checkpoint_filename_with_path)
192      return sess, True
193
194    # Waits up until max_wait_secs for checkpoint to become available.
195    wait_time = 0
196    ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
197    while not ckpt or not ckpt.model_checkpoint_path:
198      if wait_for_checkpoint and wait_time < max_wait_secs:
199        logging.info("Waiting for checkpoint to be available.")
200        time.sleep(self._recovery_wait_secs)
201        wait_time += self._recovery_wait_secs
202        ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
203      else:
204        return sess, False
205
206    # Loads the checkpoint.
207    saver.restore(sess, ckpt.model_checkpoint_path)
208    saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
209    return sess, True
210
211  def prepare_session(self,
212                      master,
213                      init_op=None,
214                      saver=None,
215                      checkpoint_dir=None,
216                      checkpoint_filename_with_path=None,
217                      wait_for_checkpoint=False,
218                      max_wait_secs=7200,
219                      config=None,
220                      init_feed_dict=None,
221                      init_fn=None):
222    """Creates a `Session`. Makes sure the model is ready to be used.
223
224    Creates a `Session` on 'master'. If a `saver` object is passed in, and
225    `checkpoint_dir` points to a directory containing valid checkpoint
226    files, then it will try to recover the model from checkpoint. If
227    no checkpoint files are available, and `wait_for_checkpoint` is
228    `True`, then the process would check every `recovery_wait_secs`,
229    up to `max_wait_secs`, for recovery to succeed.
230
231    If the model cannot be recovered successfully then it is initialized by
232    either running the provided `init_op`, or calling the provided `init_fn`.
233    The local_init_op is also run after init_op and init_fn, regardless of
234    whether the model was recovered successfully, but only if
235    ready_for_local_init_op passes.
236
237    It is an error if the model cannot be recovered and no `init_op`
238    or `init_fn` or `local_init_op` are passed.
239
240    Args:
241      master: `String` representation of the TensorFlow master to use.
242      init_op: Optional `Operation` used to initialize the model.
243      saver: A `Saver` object used to restore a model.
244      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
245        dir will be used to restore.
246      checkpoint_filename_with_path: Full file name path to the checkpoint file.
247      wait_for_checkpoint: Whether to wait for checkpoint to become available.
248      max_wait_secs: Maximum time to wait for checkpoints to become available.
249      config: Optional `ConfigProto` proto used to configure the session.
250      init_feed_dict: Optional dictionary that maps `Tensor` objects to feed
251        values.  This feed dictionary is passed to the session `run()` call when
252        running the init op.
253      init_fn: Optional callable used to initialize the model. Called after the
254        optional `init_op` is called.  The callable must accept one argument,
255        the session being initialized.
256
257    Returns:
258      A `Session` object that can be used to drive the model.
259
260    Raises:
261      RuntimeError: If the model cannot be initialized or recovered.
262
263    Raises:
264      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
265        set.
266    """
267
268    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
269        master,
270        saver,
271        checkpoint_dir=checkpoint_dir,
272        checkpoint_filename_with_path=checkpoint_filename_with_path,
273        wait_for_checkpoint=wait_for_checkpoint,
274        max_wait_secs=max_wait_secs,
275        config=config)
276    if not is_loaded_from_checkpoint:
277      if init_op is None and not init_fn and self._local_init_op is None:
278        raise RuntimeError("Model is not initialized and no init_op or "
279                           "init_fn or local_init_op was given")
280      if init_op is not None:
281        sess.run(init_op, feed_dict=init_feed_dict)
282      if init_fn:
283        init_fn(sess)
284
285    local_init_success, msg = self._try_run_local_init_op(sess)
286    if not local_init_success:
287      raise RuntimeError(
288          "Init operations did not make model ready for local_init.  "
289          "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op),
290                                                   init_fn,
291                                                   msg))
292
293    is_ready, msg = self._model_ready(sess)
294    if not is_ready:
295      raise RuntimeError(
296          "Init operations did not make model ready.  "
297          "Init op: %s, init fn: %s, local_init_op: %s, error: %s" %
298          (_maybe_name(init_op), init_fn, self._local_init_op, msg))
299    return sess
300
301  def recover_session(self,
302                      master,
303                      saver=None,
304                      checkpoint_dir=None,
305                      checkpoint_filename_with_path=None,
306                      wait_for_checkpoint=False,
307                      max_wait_secs=7200,
308                      config=None):
309    """Creates a `Session`, recovering if possible.
310
311    Creates a new session on 'master'.  If the session is not initialized
312    and can be recovered from a checkpoint, recover it.
313
314    Args:
315      master: `String` representation of the TensorFlow master to use.
316      saver: A `Saver` object used to restore a model.
317      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
318        dir will be used to restore.
319      checkpoint_filename_with_path: Full file name path to the checkpoint file.
320      wait_for_checkpoint: Whether to wait for checkpoint to become available.
321      max_wait_secs: Maximum time to wait for checkpoints to become available.
322      config: Optional `ConfigProto` proto used to configure the session.
323
324    Returns:
325      A pair (sess, initialized) where 'initialized' is `True` if
326      the session could be recovered and initialized, `False` otherwise.
327
328    Raises:
329      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
330        set.
331    """
332
333    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
334        master,
335        saver,
336        checkpoint_dir=checkpoint_dir,
337        checkpoint_filename_with_path=checkpoint_filename_with_path,
338        wait_for_checkpoint=wait_for_checkpoint,
339        max_wait_secs=max_wait_secs,
340        config=config)
341
342    # Always try to run local_init_op
343    local_init_success, msg = self._try_run_local_init_op(sess)
344
345    if not is_loaded_from_checkpoint:
346      # Do not need to run checks for readiness
347      return sess, False
348
349    restoring_file = checkpoint_dir or checkpoint_filename_with_path
350    if not local_init_success:
351      logging.info(
352          "Restoring model from %s did not make model ready for local init:"
353          " %s", restoring_file, msg)
354      return sess, False
355
356    is_ready, msg = self._model_ready(sess)
357    if not is_ready:
358      logging.info("Restoring model from %s did not make model ready: %s",
359                   restoring_file, msg)
360      return sess, False
361
362    logging.info("Restored model from %s", restoring_file)
363    return sess, is_loaded_from_checkpoint
364
365  def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
366    """Creates a new `Session` and waits for model to be ready.
367
368    Creates a new `Session` on 'master'.  Waits for the model to be
369    initialized or recovered from a checkpoint.  It's expected that
370    another thread or process will make the model ready, and that this
371    is intended to be used by threads/processes that participate in a
372    distributed training configuration where a different thread/process
373    is responsible for initializing or recovering the model being trained.
374
375    NB: The amount of time this method waits for the session is bounded
376    by max_wait_secs. By default, this function will wait indefinitely.
377
378    Args:
379      master: `String` representation of the TensorFlow master to use.
380      config: Optional ConfigProto proto used to configure the session.
381      max_wait_secs: Maximum time to wait for the session to become available.
382
383    Returns:
384      A `Session`. May be None if the operation exceeds the timeout
385      specified by config.operation_timeout_in_ms.
386
387    Raises:
388      tf.DeadlineExceededError: if the session is not available after
389        max_wait_secs.
390    """
391    self._target = master
392
393    if max_wait_secs is None:
394      max_wait_secs = float("Inf")
395    timer = _CountDownTimer(max_wait_secs)
396
397    while True:
398      sess = session.Session(self._target, graph=self._graph, config=config)
399      not_ready_msg = None
400      not_ready_local_msg = None
401      local_init_success, not_ready_local_msg = self._try_run_local_init_op(
402          sess)
403      if local_init_success:
404        # Successful if local_init_op is None, or ready_for_local_init_op passes
405        is_ready, not_ready_msg = self._model_ready(sess)
406        if is_ready:
407          return sess
408
409      self._safe_close(sess)
410
411      # Do we have enough time left to try again?
412      remaining_ms_after_wait = (
413          timer.secs_remaining() - self._recovery_wait_secs)
414      if remaining_ms_after_wait < 0:
415        raise errors.DeadlineExceededError(
416            None, None,
417            "Session was not ready after waiting %d secs." % (max_wait_secs,))
418
419      logging.info("Waiting for model to be ready.  "
420                   "Ready_for_local_init_op:  %s, ready: %s",
421                   not_ready_local_msg, not_ready_msg)
422      time.sleep(self._recovery_wait_secs)
423
424  def _safe_close(self, sess):
425    """Closes a session without raising an exception.
426
427    Just like sess.close() but ignores exceptions.
428
429    Args:
430      sess: A `Session`.
431    """
432    # pylint: disable=broad-except
433    try:
434      sess.close()
435    except Exception:
436      # Intentionally not logging to avoid user complaints that
437      # they get cryptic errors.  We really do not care that Close
438      # fails.
439      pass
440    # pylint: enable=broad-except
441
442  def _model_ready(self, sess):
443    """Checks if the model is ready or not.
444
445    Args:
446      sess: A `Session`.
447
448    Returns:
449      A tuple (is_ready, msg), where is_ready is True if ready and False
450      otherwise, and msg is `None` if the model is ready, a `String` with the
451      reason why it is not ready otherwise.
452    """
453    return _ready(self._ready_op, sess, "Model not ready")
454
455  def _model_ready_for_local_init(self, sess):
456    """Checks if the model is ready to run local_init_op.
457
458    Args:
459      sess: A `Session`.
460
461    Returns:
462      A tuple (is_ready, msg), where is_ready is True if ready to run
463      local_init_op and False otherwise, and msg is `None` if the model is
464      ready to run local_init_op, a `String` with the reason why it is not ready
465      otherwise.
466    """
467    return _ready(self._ready_for_local_init_op, sess,
468                  "Model not ready for local init")
469
470  def _try_run_local_init_op(self, sess):
471    """Tries to run _local_init_op, if not None, and is ready for local init.
472
473    Args:
474      sess: A `Session`.
475
476    Returns:
477      A tuple (is_successful, msg), where is_successful is True if
478      _local_init_op is None, or we ran _local_init_op, and False otherwise;
479      and msg is a `String` with the reason why the model was not ready to run
480      local init.
481    """
482    if self._local_init_op is not None:
483      is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
484      if is_ready_for_local_init:
485        logging.info("Running local_init_op.")
486        sess.run(self._local_init_op)
487        logging.info("Done running local_init_op.")
488        return True, None
489      else:
490        return False, msg
491    return True, None
492
493
494def _ready(op, sess, msg):
495  """Checks if the model is ready or not, as determined by op.
496
497  Args:
498    op: An op, either _ready_op or _ready_for_local_init_op, which defines the
499      readiness of the model.
500    sess: A `Session`.
501    msg: A message to log to warning if not ready
502
503  Returns:
504    A tuple (is_ready, msg), where is_ready is True if ready and False
505    otherwise, and msg is `None` if the model is ready, a `String` with the
506    reason why it is not ready otherwise.
507  """
508  if op is None:
509    return True, None
510  else:
511    try:
512      ready_value = sess.run(op)
513      # The model is considered ready if ready_op returns an empty 1-D tensor.
514      # Also compare to `None` and dtype being int32 for backward
515      # compatibility.
516      if (ready_value is None or ready_value.dtype == np.int32 or
517          ready_value.size == 0):
518        return True, None
519      else:
520        # TODO(sherrym): If a custom ready_op returns other types of tensor,
521        # or strings other than variable names, this message could be
522        # confusing.
523        non_initialized_varnames = ", ".join(
524            [i.decode("utf-8") for i in ready_value])
525        return False, "Variables not initialized: " + non_initialized_varnames
526    except errors.FailedPreconditionError as e:
527      if "uninitialized" not in str(e):
528        logging.warning("%s : error [%s]", msg, str(e))
529        raise e
530      return False, str(e)
531
532
533class _CountDownTimer(object):
534
535  def __init__(self, duration_secs):
536    self._start_time_secs = time.time()
537    self._duration_secs = duration_secs
538
539  def secs_remaining(self):
540    diff = self._duration_secs - (time.time() - self._start_time_secs)
541    return max(0, diff)
542