# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=invalid-name
"""Save and restore variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os.path
import re
import time

from google.protobuf import text_format

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _GetCheckpointFilename(save_dir, latest_filename):
  """Returns a filename for storing the CheckpointState.

  Args:
    save_dir: The directory for saving and restoring checkpoints.
    latest_filename: Name of the file in 'save_dir' that is used
      to store the CheckpointState.

  Returns:
    The path of the file that contains the CheckpointState proto.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  return os.path.join(save_dir, latest_filename)


@tf_export(v1=["train.generate_checkpoint_state_proto"])
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None,
                                    all_model_checkpoint_timestamps=None,
                                    last_preserved_timestamp=None):
  """Generates a checkpoint state proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    all_model_checkpoint_timestamps: A list of floats, indicating the number of
      seconds since the Epoch when each checkpoint was generated.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).
  Returns:
    CheckpointState proto with model_checkpoint_path and
    all_model_checkpoint_paths updated to either absolute paths or
    relative paths to the current save_dir.

  Raises:
    ValueError: If `all_model_checkpoint_timestamps` was provided but its length
      does not match `all_model_checkpoint_paths`.
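
  Example (a minimal sketch; the directory and checkpoint names below are
  hypothetical):

  ```python
  state = generate_checkpoint_state_proto(
      "/tmp/train_dir",
      "/tmp/train_dir/model.ckpt-2",
      all_model_checkpoint_paths=["/tmp/train_dir/model.ckpt-1",
                                  "/tmp/train_dir/model.ckpt-2"])
  # Because `save_dir` is absolute here, the paths are stored as passed; with
  # a relative `save_dir`, relative paths would be rewritten relative to it.
  ```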
  """
  if all_model_checkpoint_paths is None:
    all_model_checkpoint_paths = []

  if (not all_model_checkpoint_paths or
      all_model_checkpoint_paths[-1] != model_checkpoint_path):
    logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
                 model_checkpoint_path)
    all_model_checkpoint_paths.append(model_checkpoint_path)

  if (all_model_checkpoint_timestamps
      and (len(all_model_checkpoint_timestamps)
           != len(all_model_checkpoint_paths))):
    raise ValueError(
        ("Checkpoint timestamps, if provided, must match checkpoint paths (got "
         "paths %s and timestamps %s)")
        % (all_model_checkpoint_paths, all_model_checkpoint_timestamps))

  # Relative paths need to be rewritten to be relative to the "save_dir"
  # if model_checkpoint_path already contains "save_dir".
  if not os.path.isabs(save_dir):
    if not os.path.isabs(model_checkpoint_path):
      model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
    for i, p in enumerate(all_model_checkpoint_paths):
      if not os.path.isabs(p):
        all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)

  coord_checkpoint_proto = CheckpointState(
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
      last_preserved_timestamp=last_preserved_timestamp)

  return coord_checkpoint_proto


@deprecation.deprecated(
    date=None,
    instructions=("Use `tf.train.CheckpointManager` to manage checkpoints "
                  "rather than manually editing the Checkpoint proto."))
@tf_export(v1=["train.update_checkpoint_state"])
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None,
                            all_model_checkpoint_timestamps=None,
                            last_preserved_timestamp=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.
    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
      seconds since the Epoch) indicating when the checkpoints in
      `all_model_checkpoint_paths` were created.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).
  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
  """
  update_checkpoint_state_internal(
      save_dir=save_dir,
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      latest_filename=latest_filename,
      save_relative_paths=False,
      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
      last_preserved_timestamp=last_preserved_timestamp)


def update_checkpoint_state_internal(save_dir,
                                     model_checkpoint_path,
                                     all_model_checkpoint_paths=None,
                                     latest_filename=None,
                                     save_relative_paths=False,
                                     all_model_checkpoint_timestamps=None,
                                     last_preserved_timestamp=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.
    save_relative_paths: If `True`, will write relative paths to the checkpoint
      state file.
    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
      seconds since the Epoch) indicating when the checkpoints in
      `all_model_checkpoint_paths` were created.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).

  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
  """
  # Writes the "checkpoint" file for the coordinator for later restoration.
  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
  if save_relative_paths:
    if os.path.isabs(model_checkpoint_path):
      rel_model_checkpoint_path = os.path.relpath(
          model_checkpoint_path, save_dir)
    else:
      rel_model_checkpoint_path = model_checkpoint_path
    rel_all_model_checkpoint_paths = []
    for p in all_model_checkpoint_paths:
      if os.path.isabs(p):
        rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
      else:
        rel_all_model_checkpoint_paths.append(p)
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        rel_model_checkpoint_path,
        all_model_checkpoint_paths=rel_all_model_checkpoint_paths,
        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
        last_preserved_timestamp=last_preserved_timestamp)
  else:
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths,
        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
        last_preserved_timestamp=last_preserved_timestamp)

  if coord_checkpoint_filename == ckpt.model_checkpoint_path:
    raise RuntimeError("Save path '%s' conflicts with path used for "
                       "checkpoint state.  Please use a different save path." %
                       model_checkpoint_path)

  # Preventing potential read/write race condition by *atomically* writing to a
  # file.
  file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                      text_format.MessageToString(ckpt))


@tf_export("train.get_checkpoint_state")
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
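
  Example (a minimal sketch; the directory path below is hypothetical):

  ```python
  ckpt = tf.train.get_checkpoint_state("/tmp/train_dir")
  if ckpt is not None:
    print(ckpt.model_checkpoint_path)
    print(list(ckpt.all_model_checkpoint_paths))
  ```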
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  f = None
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if file_io.file_exists(coord_checkpoint_filename):
      file_content = file_io.read_file_to_string(
          coord_checkpoint_filename)
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      if not ckpt.model_checkpoint_path:
        raise ValueError("Invalid checkpoint state loaded from "
                         + checkpoint_dir)
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(ckpt.model_checkpoint_path):
        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                  ckpt.model_checkpoint_path)
      for i, p in enumerate(ckpt.all_model_checkpoint_paths):
        if not os.path.isabs(p):
          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except errors.OpError as e:
    # It's ok if the file cannot be read
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  except text_format.ParseError as e:
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  finally:
    if f:
      f.close()
  return ckpt


def _prefix_to_checkpoint_path(prefix, format_version):
  """Returns the pathname of a checkpoint file, given the checkpoint prefix.

  For V1 checkpoint, simply returns the prefix itself (the data file).  For V2,
  returns the pathname to the index file.

  Args:
    prefix: a string, the prefix of a checkpoint.
    format_version: the checkpoint format version that corresponds to the
      prefix.
  Returns:
    The pathname of a checkpoint file, taking into account the checkpoint
      format version.
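
  Example (illustrative only; the prefix below is hypothetical):

  ```python
  _prefix_to_checkpoint_path("/tmp/model.ckpt-1", saver_pb2.SaverDef.V2)
  # -> "/tmp/model.ckpt-1.index"
  _prefix_to_checkpoint_path("/tmp/model.ckpt-1", saver_pb2.SaverDef.V1)
  # -> "/tmp/model.ckpt-1"
  ```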
  """
  if format_version == saver_pb2.SaverDef.V2:
    return prefix + ".index"  # The index file identifies a checkpoint.
  return prefix  # Just the data file.


@tf_export("train.latest_checkpoint")
def latest_checkpoint(checkpoint_dir, latest_filename=None):
  """Finds the filename of the latest saved checkpoint file.

  Gets the checkpoint state given the provided checkpoint_dir and looks for a
  corresponding TensorFlow 2 (preferred) or TensorFlow 1.x checkpoint path.
  The latest_filename argument is only applicable if you are saving checkpoints
  using `v1.Saver.save`.

  See the [Training Checkpoints
  Guide](https://www.tensorflow.org/guide/checkpoint) for more details and
  examples.

  Args:
    checkpoint_dir: Directory where the variables were saved.
    latest_filename: Optional name for the protocol buffer file that
      contains the list of most recent checkpoint filenames.
      See the corresponding argument to `v1.Saver.save`.

  Returns:
    The full path to the latest checkpoint or `None` if no checkpoint was found.
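
  Example (a minimal sketch; the directory path below is hypothetical):

  ```python
  latest = tf.train.latest_checkpoint("/tmp/train_dir")
  if latest is not None:
    # e.g. "/tmp/train_dir/model.ckpt-42"; pass this prefix to a restore API.
    print(latest)
  ```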
  """
  # Pick the latest checkpoint based on checkpoint state.
  ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
  if ckpt and ckpt.model_checkpoint_path:
    # Look for either a V2 path or a V1 path, with priority for V2.
    v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V2)
    v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V1)
    if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
        v1_path):
      return ckpt.model_checkpoint_path
    else:
      logging.error("Couldn't match files for checkpoint %s",
                    ckpt.model_checkpoint_path)
  return None


def checkpoint_exists_internal(checkpoint_prefix):
  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.

  This is an internal function to check if a checkpoint exists,
  since it takes into account the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
  Returns:
    A bool, true if a checkpoint referred to by `checkpoint_prefix` exists.
  """
  pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                        saver_pb2.SaverDef.V2)
  if file_io.get_matching_files(pathname):
    return True
  elif file_io.get_matching_files(checkpoint_prefix):
    return True
  else:
    return False


@deprecation.deprecated(
    date=None,
    instructions="Use standard file APIs to check for files with this prefix.")
@tf_export(v1=["train.checkpoint_exists"])
def checkpoint_exists(checkpoint_prefix):
  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.

  This is the recommended way to check if a checkpoint exists, since it takes
  into account the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.

  Returns:
    A bool, true if a checkpoint referred to by `checkpoint_prefix` exists.
  """
  return checkpoint_exists_internal(checkpoint_prefix)


@deprecation.deprecated(
    date=None,
    instructions="Use standard file utilities to get mtimes.")
@tf_export(v1=["train.get_checkpoint_mtimes"])
def get_checkpoint_mtimes(checkpoint_prefixes):
  """Returns the mtimes (modification timestamps) of the checkpoints.

  Globs for the checkpoints pointed to by `checkpoint_prefixes`.  If the files
  exist, collect their mtime.  Both V2 and V1 checkpoints are considered, in
  that priority.

  This is the recommended way to get the mtimes, since it takes into account
  the naming difference between V1 and V2 formats.

  Note: If not all checkpoints exist, the length of the returned mtimes list
  will be smaller than the length of `checkpoint_prefixes` list, so mapping
  checkpoints to corresponding mtimes will not be possible.

  Args:
    checkpoint_prefixes: a list of checkpoint paths, typically the results of
      `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
      sharded/non-sharded or V1/V2.
  Returns:
    A list of mtimes (in seconds since the Epoch) of the found checkpoints.
  """
  mtimes = []

  def match_maybe_append(pathname):
    fnames = file_io.get_matching_files(pathname)
    if fnames:
      mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
      return True
    return False

  for checkpoint_prefix in checkpoint_prefixes:
    # Tries V2's metadata file first.
    pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                          saver_pb2.SaverDef.V2)
    if match_maybe_append(pathname):
      continue
    # Otherwise, tries V1, where the prefix is the complete pathname.
    match_maybe_append(checkpoint_prefix)

  return mtimes


@deprecation.deprecated(
    date=None,
    instructions="Use standard file APIs to delete files with this prefix.")
@tf_export(v1=["train.remove_checkpoint"])
def remove_checkpoint(checkpoint_prefix,
                      checkpoint_format_version=saver_pb2.SaverDef.V2,
                      meta_graph_suffix="meta"):
  """Removes a checkpoint given by `checkpoint_prefix`.

  Args:
    checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result
      of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of
      sharded/non-sharded or V1/V2.
    checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
      `SaverDef.V2`.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
  """
  _delete_file_if_exists(
      meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
  if checkpoint_format_version == saver_pb2.SaverDef.V2:
    # V2 has a metadata file and some data files.
    _delete_file_if_exists(checkpoint_prefix + ".index")
    _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
  else:
    # V1, Legacy.  Exact match on the data file.
    _delete_file_if_exists(checkpoint_prefix)


def _delete_file_if_exists(filespec):
  """Deletes files matching `filespec`."""
  for pathname in file_io.get_matching_files(filespec):
    file_io.delete_file(pathname)


def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
  """Returns the meta graph filename.

  Args:
    checkpoint_filename: Name of the checkpoint file.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.

  Returns:
    MetaGraph file name.
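
  Example (illustrative only; the checkpoint names below are hypothetical):

  ```python
  meta_graph_filename("model.ckpt-123456")
  # -> "model.ckpt-123456.meta"
  meta_graph_filename("model.ckpt-123456-?????-of-00005")
  # -> "model.ckpt-123456.meta" (the shard suffix is stripped)
  ```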
  """
  # If the checkpoint_filename is sharded, the checkpoint_filename could
  # be of format model.ckpt-step#-?????-of-shard#. For example,
  # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
  basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
  suffixed_filename = ".".join([basename, meta_graph_suffix])
  return suffixed_filename


# TODO(allenl): Allow tf.keras.Model instances in the constructor directly?
@tf_export("train.CheckpointManager")
class CheckpointManager(object):
  """Deletes old checkpoints.

  Example usage:

  ```python
  import tensorflow as tf
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(
      checkpoint, directory="/tmp/model", max_to_keep=5)
  status = checkpoint.restore(manager.latest_checkpoint)
  while True:
    # train
    manager.save()
  ```

  `CheckpointManager` preserves its own state across instantiations (see the
  `__init__` documentation for details). Only one should be active in a
  particular directory at a time.
  """

  def __init__(self,
               checkpoint,
               directory,
               max_to_keep,
               keep_checkpoint_every_n_hours=None,
               checkpoint_name="ckpt"):
    """Configure a `CheckpointManager` for use in `directory`.

    If a `CheckpointManager` was previously used in `directory`, its
    state will be restored. This includes the list of managed checkpoints and
    the timestamp bookkeeping necessary to support
    `keep_checkpoint_every_n_hours`. The behavior of the new `CheckpointManager`
    will be the same as the previous `CheckpointManager`, including cleaning up
    existing checkpoints if appropriate.

    Checkpoints are only considered for deletion just after a new checkpoint has
    been added. At that point, `max_to_keep` checkpoints will remain in an
    "active set". Once a checkpoint is preserved by
    `keep_checkpoint_every_n_hours` it will not be deleted by this
    `CheckpointManager` or any future `CheckpointManager` instantiated in
    `directory` (regardless of the new setting of
    `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the
    active set may be deleted by this `CheckpointManager` or a future
    `CheckpointManager` instantiated in `directory` (subject to its
    `max_to_keep` and `keep_checkpoint_every_n_hours` settings).

    Args:
      checkpoint: The `tf.train.Checkpoint` instance to save and manage
        checkpoints for.
      directory: The path to a directory in which to write checkpoints. A
        special file named "checkpoint" is also written to this directory (in a
        human-readable text format) which contains the state of the
        `CheckpointManager`.
      max_to_keep: An integer, the number of checkpoints to keep. Unless
        preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
        deleted from the active set, oldest first, until only `max_to_keep`
        checkpoints remain. If `None`, no checkpoints are deleted and everything
        stays in the active set. Note that `max_to_keep=None` will keep all
        checkpoint paths in memory and in the checkpoint state protocol buffer
        on disk.
      keep_checkpoint_every_n_hours: Upon removal from the active set, a
        checkpoint will be preserved if it has been at least
        `keep_checkpoint_every_n_hours` hours since the last preserved
        checkpoint. The default setting of `None` does not preserve any
        checkpoints in this way.
      checkpoint_name: Custom name for the checkpoint file.

    Raises:
      ValueError: If `max_to_keep` is not a positive integer.
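
    Example (a minimal sketch; the directory path below is hypothetical and
    `model` stands for any existing trackable object):

    ```python
    checkpoint = tf.train.Checkpoint(model=model)
    manager = tf.train.CheckpointManager(
        checkpoint, directory="/tmp/model", max_to_keep=3,
        keep_checkpoint_every_n_hours=2)
    ```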
    """
    self._checkpoint = checkpoint
    self._save_counter_assign = None
    if max_to_keep is not None and max_to_keep <= 0:
      raise ValueError(
          ("Expected a positive integer or `None` for `max_to_keep`, "
           "got %d.")
          % (max_to_keep,))
    self._max_to_keep = max_to_keep
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    self._directory = directory
    self._checkpoint_prefix = os.path.join(directory, checkpoint_name)
    recovered_state = get_checkpoint_state(directory)
    current_clock = time.time()
    self._maybe_delete = collections.OrderedDict()
    if recovered_state is None:
      self._latest_checkpoint = None
      # Set the clock back slightly to avoid race conditions when quickly
      # re-creating a CheckpointManager.
      self._last_preserved_timestamp = current_clock - 1.
    else:
      self._latest_checkpoint = recovered_state.model_checkpoint_path
      self._last_preserved_timestamp = recovered_state.last_preserved_timestamp
      if current_clock < self._last_preserved_timestamp:
        # Time seems to have reversed itself. In addition to this warning, we'll
        # min() saved checkpoint timestamps with the current time to ensure that
        # old checkpoints don't get deleted accidentally.
        logging.warning(
            ("time.time() returned a value %f seconds behind the last "
             "preserved checkpoint timestamp.")
            % (self._last_preserved_timestamp - current_clock,))
        self._last_preserved_timestamp = current_clock
      all_timestamps = recovered_state.all_model_checkpoint_timestamps
      all_paths = recovered_state.all_model_checkpoint_paths
      del recovered_state  # Uses modified values from now on
      if not all_timestamps:
        all_timestamps = [self._last_preserved_timestamp] * len(all_paths)

      for filename, timestamp in zip(all_paths, all_timestamps):
        timestamp = min(timestamp, current_clock)
        if timestamp > self._last_preserved_timestamp:
          self._maybe_delete[filename] = timestamp

  @property
  def directory(self):
    return self._directory

  @property
  def latest_checkpoint(self):
    """The prefix of the most recent checkpoint in `directory`.

    Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is
    the constructor argument to `CheckpointManager`.

    Suitable for passing to `tf.train.Checkpoint.restore` to resume training.

    Returns:
      The checkpoint prefix. If there are no checkpoints, returns `None`.
    """
    return self._latest_checkpoint

  @property
  def checkpoints(self):
    """A list of managed checkpoints.

    Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not
    show up in this list (to avoid ever-growing filename lists).

    Returns:
      A list of filenames, sorted from oldest to newest.
    """
    return list(self._maybe_delete.keys())

  def _sweep(self):
    """Deletes or preserves managed checkpoints."""
    if not self._max_to_keep:
      # Does not update self._last_preserved_timestamp, since everything is kept
      # in the active set.
      return
    while len(self._maybe_delete) > self._max_to_keep:
      filename, timestamp = self._maybe_delete.popitem(last=False)
      # Even if we're keeping this checkpoint due to
      # keep_checkpoint_every_n_hours, we won't reference it to avoid
      # infinitely-growing CheckpointState protos.
      if (self._keep_checkpoint_every_n_hours
          and (timestamp - self._keep_checkpoint_every_n_hours * 3600.
               >= self._last_preserved_timestamp)):
        self._last_preserved_timestamp = timestamp
        continue
      _delete_file_if_exists(filename + ".index")
      _delete_file_if_exists(filename + ".data-?????-of-?????")

  def _record_state(self):
    """Saves the `CheckpointManager`'s state in `directory`."""
    filenames, timestamps = zip(*self._maybe_delete.items())
    update_checkpoint_state_internal(
        self._directory,
        model_checkpoint_path=self.latest_checkpoint,
        all_model_checkpoint_paths=filenames,
        all_model_checkpoint_timestamps=timestamps,
        last_preserved_timestamp=self._last_preserved_timestamp,
        save_relative_paths=True)

  @property
  def _prefix(self):
    """A common prefix for all checkpoints saved with this manager.

    For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`,
    `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be
    numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each
    checkpoint has several associated files
    (e.g. `"/tmp/tf-model/ckpt-2.index"`).

    Returns:
      A string prefix.
    """
    return self._checkpoint_prefix

  def save(self, checkpoint_number=None):
    """Creates a new checkpoint and manages it.

    Args:
      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
        `Tensor`, used to number the checkpoint. If `None` (default),
        checkpoints are numbered using `checkpoint.save_counter`. Even if
        `checkpoint_number` is provided, `save_counter` is still incremented. A
        user-provided `checkpoint_number` is not incremented even if it is a
        `Variable`.

    Returns:
      The path to the new checkpoint. It is also recorded in the `checkpoints`
      and `latest_checkpoint` properties.
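
    Example (a minimal sketch; `manager` is a `CheckpointManager` and the step
    number below is hypothetical):

    ```python
    # Numbered by `checkpoint.save_counter`:
    path = manager.save()
    # Or numbered explicitly, e.g. by a training step:
    path = manager.save(checkpoint_number=1000)  # e.g. ".../ckpt-1000"
    ```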
    """
    # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
    # slightly with a custom numbering option.
    if context.executing_eagerly():
      save_counter = self._checkpoint.save_counter
      save_counter.assign_add(1)
      session = None
    else:
      session = ops.get_default_session()

      def _initializing_creator(next_creator, **kwargs):
        """Initialize the save counter if it has been newly created."""
        v = next_creator(**kwargs)
        session.run(v.initializer)
        return v

      with variable_scope.variable_creator_scope(_initializing_creator):
        save_counter = self._checkpoint.save_counter
      if self._save_counter_assign is None:
        self._save_counter_assign = save_counter.assign_add(1, read_value=False)
      session.run(self._save_counter_assign)
    if checkpoint_number is None:
      checkpoint_number = save_counter
    if not isinstance(checkpoint_number, compat.integral_types):
      checkpoint_number = training_util.global_step(
          sess=session, global_step_tensor=checkpoint_number)
    prefix = "%s-%d" % (self._prefix, checkpoint_number)
    save_path = self._checkpoint.write(prefix)
    timestamp = time.time()
    # If this is an overwritten checkpoint we were previously tracking, delete
    # and reinsert it to make sure it goes to the end of the queue.
    if save_path in self._maybe_delete:
      del self._maybe_delete[save_path]
    self._maybe_delete[save_path] = timestamp
    self._latest_checkpoint = save_path
    # Before deleting anything we update the Checkpoint proto with the new
    # checkpoint. We'll go back and correct it after cleaning up old files, but
    # a preemption while deleting will be more likely to see the new checkpoint
    # this way.
    self._record_state()
    self._sweep()
    # Write out the Checkpoint proto a second time, now without the deleted
    # checkpoints.
    self._record_state()
    return save_path