# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=invalid-name
"""Save and restore variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os.path
import re
import time

from google.protobuf import text_format

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _GetCheckpointFilename(save_dir, latest_filename):
44  """Returns a filename for storing the CheckpointState.
45
46  Args:
47    save_dir: The directory for saving and restoring checkpoints.
48    latest_filename: Name of the file in 'save_dir' that is used
49      to store the CheckpointState.
50
51  Returns:
52    The path of the file that contains the CheckpointState proto.
53  """
54  if latest_filename is None:
55    latest_filename = "checkpoint"
56  return os.path.join(save_dir, latest_filename)
57
58
59@tf_export(v1=["train.generate_checkpoint_state_proto"])
60def generate_checkpoint_state_proto(save_dir,
61                                    model_checkpoint_path,
62                                    all_model_checkpoint_paths=None,
63                                    all_model_checkpoint_timestamps=None,
64                                    last_preserved_timestamp=None):
65  """Generates a checkpoint state proto.
66
67  Args:
68    save_dir: Directory where the model was saved.
69    model_checkpoint_path: The checkpoint file.
70    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
71      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
72      the last element must be equal to model_checkpoint_path.  These paths
73      are also saved in the CheckpointState proto.
74    all_model_checkpoint_timestamps: A list of floats, indicating the number of
75      seconds since the Epoch when each checkpoint was generated.
76    last_preserved_timestamp: A float, indicating the number of seconds since
77      the Epoch when the last preserved checkpoint was written, e.g. due to a
78      `keep_checkpoint_every_n_hours` parameter (see
79      `tf.contrib.checkpoint.CheckpointManager` for an implementation).
80  Returns:
81    CheckpointState proto with model_checkpoint_path and
82    all_model_checkpoint_paths updated to either absolute paths or
83    relative paths to the current save_dir.
84
85  Raises:
86    ValueError: If `all_model_checkpoint_timestamps` was provided but its length
87      does not match `all_model_checkpoint_paths`.
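
  Example (a sketch; assumes hypothetical checkpoints under an absolute
  `/tmp/model` directory, so no relative-path rewriting occurs):

  ```python
  ckpt = generate_checkpoint_state_proto(
      "/tmp/model", "/tmp/model/ckpt-2",
      all_model_checkpoint_paths=["/tmp/model/ckpt-1", "/tmp/model/ckpt-2"])
  print(ckpt.model_checkpoint_path)  # "/tmp/model/ckpt-2"
  ```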
88  """
89  if all_model_checkpoint_paths is None:
90    all_model_checkpoint_paths = []
91
92  if (not all_model_checkpoint_paths or
93      all_model_checkpoint_paths[-1] != model_checkpoint_path):
94    logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
95                 model_checkpoint_path)
96    all_model_checkpoint_paths.append(model_checkpoint_path)
97
98  if (all_model_checkpoint_timestamps
99      and (len(all_model_checkpoint_timestamps)
100           != len(all_model_checkpoint_paths))):
101    raise ValueError(
102        ("Checkpoint timestamps, if provided, must match checkpoint paths (got "
103         "paths %s and timestamps %s)")
104        % (all_model_checkpoint_paths, all_model_checkpoint_timestamps))
105
  # If "save_dir" is itself a relative path, rewrite any relative checkpoint
  # paths so that they are expressed relative to "save_dir".
  if not os.path.isabs(save_dir):
    if not os.path.isabs(model_checkpoint_path):
      model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
    for i in range(len(all_model_checkpoint_paths)):
      p = all_model_checkpoint_paths[i]
      if not os.path.isabs(p):
        all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)

  coord_checkpoint_proto = CheckpointState(
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
      last_preserved_timestamp=last_preserved_timestamp)

  return coord_checkpoint_proto


@deprecation.deprecated(
    date=None,
    instructions=("Use `tf.train.CheckpointManager` to manage checkpoints "
                  "rather than manually editing the Checkpoint proto."))
@tf_export(v1=["train.update_checkpoint_state"])
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None,
                            all_model_checkpoint_timestamps=None,
                            last_preserved_timestamp=None):
136  """Updates the content of the 'checkpoint' file.
137
138  This updates the checkpoint file containing a CheckpointState
139  proto.
140
141  Args:
142    save_dir: Directory where the model was saved.
143    model_checkpoint_path: The checkpoint file.
144    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
145      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
146      the last element must be equal to model_checkpoint_path.  These paths
147      are also saved in the CheckpointState proto.
148    latest_filename: Optional name of the checkpoint file.  Default to
149      'checkpoint'.
150    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
151      seconds since the Epoch) indicating when the checkpoints in
152      `all_model_checkpoint_paths` were created.
153    last_preserved_timestamp: A float, indicating the number of seconds since
154      the Epoch when the last preserved checkpoint was written, e.g. due to a
155      `keep_checkpoint_every_n_hours` parameter (see
156      `tf.contrib.checkpoint.CheckpointManager` for an implementation).
157  Raises:
158    RuntimeError: If any of the model checkpoint paths conflict with the file
159      containing CheckpointSate.
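
  Example (a sketch; assumes hypothetical checkpoints already written under
  `/tmp/model`):

  ```python
  update_checkpoint_state(
      "/tmp/model", "/tmp/model/ckpt-2",
      all_model_checkpoint_paths=["/tmp/model/ckpt-1", "/tmp/model/ckpt-2"])
  # /tmp/model/checkpoint now names ckpt-2 as the latest checkpoint.
  ```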
160  """
161  update_checkpoint_state_internal(
162      save_dir=save_dir,
163      model_checkpoint_path=model_checkpoint_path,
164      all_model_checkpoint_paths=all_model_checkpoint_paths,
165      latest_filename=latest_filename,
166      save_relative_paths=False,
167      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
168      last_preserved_timestamp=last_preserved_timestamp)
169
170
171def update_checkpoint_state_internal(save_dir,
172                                     model_checkpoint_path,
173                                     all_model_checkpoint_paths=None,
174                                     latest_filename=None,
175                                     save_relative_paths=False,
176                                     all_model_checkpoint_timestamps=None,
177                                     last_preserved_timestamp=None):
178  """Updates the content of the 'checkpoint' file.
179
180  This updates the checkpoint file containing a CheckpointState
181  proto.
182
183  Args:
184    save_dir: Directory where the model was saved.
185    model_checkpoint_path: The checkpoint file.
186    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
187      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
188      the last element must be equal to model_checkpoint_path.  These paths
189      are also saved in the CheckpointState proto.
190    latest_filename: Optional name of the checkpoint file.  Default to
191      'checkpoint'.
192    save_relative_paths: If `True`, will write relative paths to the checkpoint
193      state file.
194    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
195      seconds since the Epoch) indicating when the checkpoints in
196      `all_model_checkpoint_paths` were created.
197    last_preserved_timestamp: A float, indicating the number of seconds since
198      the Epoch when the last preserved checkpoint was written, e.g. due to a
199      `keep_checkpoint_every_n_hours` parameter (see
200      `tf.contrib.checkpoint.CheckpointManager` for an implementation).
201
202  Raises:
203    RuntimeError: If any of the model checkpoint paths conflict with the file
204      containing CheckpointSate.
205  """
  # Writes the "checkpoint" file for the coordinator for later restoration.
  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
  if save_relative_paths:
    if os.path.isabs(model_checkpoint_path):
      rel_model_checkpoint_path = os.path.relpath(
          model_checkpoint_path, save_dir)
    else:
      rel_model_checkpoint_path = model_checkpoint_path
    rel_all_model_checkpoint_paths = []
    for p in all_model_checkpoint_paths:
      if os.path.isabs(p):
        rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
      else:
        rel_all_model_checkpoint_paths.append(p)
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        rel_model_checkpoint_path,
        all_model_checkpoint_paths=rel_all_model_checkpoint_paths,
        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
        last_preserved_timestamp=last_preserved_timestamp)
  else:
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths,
        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
        last_preserved_timestamp=last_preserved_timestamp)

  if coord_checkpoint_filename == ckpt.model_checkpoint_path:
    raise RuntimeError("Save path '%s' conflicts with path used for "
                       "checkpoint state.  Please use a different save path." %
                       model_checkpoint_path)

  # Prevent a potential read/write race condition by writing to the file
  # atomically.
  file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                      text_format.MessageToString(ckpt))


@tf_export("train.get_checkpoint_state")
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
247  """Returns CheckpointState proto from the "checkpoint" file.
248
249  If the "checkpoint" file contains a valid CheckpointState
250  proto, returns it.
251
252  Args:
253    checkpoint_dir: The directory of checkpoints.
254    latest_filename: Optional name of the checkpoint file.  Default to
255      'checkpoint'.
256
257  Returns:
258    A CheckpointState if the state was available, None
259    otherwise.
260
261  Raises:
262    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
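
  Example (a sketch; assumes a `checkpoint` file was previously written to a
  hypothetical `/tmp/model` directory):

  ```python
  ckpt = get_checkpoint_state("/tmp/model")
  if ckpt is not None:
    print(ckpt.model_checkpoint_path)  # e.g. "/tmp/model/ckpt-2"
    print(list(ckpt.all_model_checkpoint_paths))
  ```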
263  """
264  ckpt = None
265  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
266                                                     latest_filename)
267  f = None
268  try:
269    # Check that the file exists before opening it to avoid
270    # many lines of errors from colossus in the logs.
271    if file_io.file_exists(coord_checkpoint_filename):
272      file_content = file_io.read_file_to_string(
273          coord_checkpoint_filename)
274      ckpt = CheckpointState()
275      text_format.Merge(file_content, ckpt)
276      if not ckpt.model_checkpoint_path:
277        raise ValueError("Invalid checkpoint state loaded from "
278                         + checkpoint_dir)
279      # For relative model_checkpoint_path and all_model_checkpoint_paths,
280      # prepend checkpoint_dir.
281      if not os.path.isabs(ckpt.model_checkpoint_path):
282        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
283                                                  ckpt.model_checkpoint_path)
284      for i in range(len(ckpt.all_model_checkpoint_paths)):
285        p = ckpt.all_model_checkpoint_paths[i]
286        if not os.path.isabs(p):
287          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
288  except errors.OpError as e:
289    # It's ok if the file cannot be read
290    logging.warning("%s: %s", type(e).__name__, e)
291    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
292    return None
293  except text_format.ParseError as e:
294    logging.warning("%s: %s", type(e).__name__, e)
295    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
296    return None
297  finally:
298    if f:
299      f.close()
300  return ckpt
301
302
303def _prefix_to_checkpoint_path(prefix, format_version):
304  """Returns the pathname of a checkpoint file, given the checkpoint prefix.
305
306  For V1 checkpoint, simply returns the prefix itself (the data file).  For V2,
307  returns the pathname to the index file.
308
309  Args:
310    prefix: a string, the prefix of a checkpoint.
311    format_version: the checkpoint format version that corresponds to the
312      prefix.
313  Returns:
314    The pathname of a checkpoint file, taking into account the checkpoint
315      format version.
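
  For example (with a hypothetical prefix):

  ```python
  _prefix_to_checkpoint_path("/tmp/model/ckpt-1", saver_pb2.SaverDef.V2)
  # => "/tmp/model/ckpt-1.index"
  _prefix_to_checkpoint_path("/tmp/model/ckpt-1", saver_pb2.SaverDef.V1)
  # => "/tmp/model/ckpt-1"
  ```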
316  """
317  if format_version == saver_pb2.SaverDef.V2:
318    return prefix + ".index"  # The index file identifies a checkpoint.
319  return prefix  # Just the data file.
320
321
322@tf_export("train.latest_checkpoint")
323def latest_checkpoint(checkpoint_dir, latest_filename=None):
324  """Finds the filename of latest saved checkpoint file.
325
326  Args:
327    checkpoint_dir: Directory where the variables were saved.
328    latest_filename: Optional name for the protocol buffer file that
329      contains the list of most recent checkpoint filenames.
330      See the corresponding argument to `Saver.save()`.
331
332  Returns:
333    The full path to the latest checkpoint or `None` if no checkpoint was found.
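
  Example (a sketch; assumes checkpoints were saved to a hypothetical
  `/tmp/model` directory, and that `saver` and `sess` already exist):

  ```python
  ckpt_path = tf.train.latest_checkpoint("/tmp/model")
  if ckpt_path is not None:
    saver.restore(sess, ckpt_path)
  ```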
334  """
335  # Pick the latest checkpoint based on checkpoint state.
336  ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
337  if ckpt and ckpt.model_checkpoint_path:
338    # Look for either a V2 path or a V1 path, with priority for V2.
339    v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
340                                         saver_pb2.SaverDef.V2)
341    v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
342                                         saver_pb2.SaverDef.V1)
343    if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
344        v1_path):
345      return ckpt.model_checkpoint_path
346    else:
347      logging.error("Couldn't match files for checkpoint %s",
348                    ckpt.model_checkpoint_path)
349  return None
350
351
352@deprecation.deprecated(
353    date=None,
354    instructions="Use standard file APIs to check for files with this prefix.")
355@tf_export(v1=["train.checkpoint_exists"])
356def checkpoint_exists(checkpoint_prefix):
357  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.
358
359  This is the recommended way to check if a checkpoint exists, since it takes
360  into account the naming difference between V1 and V2 formats.
361
362  Args:
363    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
364      priority.  Typically the result of `Saver.save()` or that of
365      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
366      V1/V2.
367  Returns:
368    A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
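
  Example (a sketch; the prefix below is hypothetical):

  ```python
  if tf.train.checkpoint_exists("/tmp/model/ckpt-1"):
    print("Found a V1 or V2 checkpoint with prefix /tmp/model/ckpt-1")
  ```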
369  """
370  pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
371                                        saver_pb2.SaverDef.V2)
372  if file_io.get_matching_files(pathname):
373    return True
374  elif file_io.get_matching_files(checkpoint_prefix):
375    return True
376  else:
377    return False
378
379
380@deprecation.deprecated(
381    date=None,
382    instructions="Use standard file utilities to get mtimes.")
383@tf_export(v1=["train.get_checkpoint_mtimes"])
384def get_checkpoint_mtimes(checkpoint_prefixes):
385  """Returns the mtimes (modification timestamps) of the checkpoints.
386
387  Globs for the checkpoints pointed to by `checkpoint_prefixes`.  If the files
388  exist, collect their mtime.  Both V2 and V1 checkpoints are considered, in
389  that priority.
390
391  This is the recommended way to get the mtimes, since it takes into account
392  the naming difference between V1 and V2 formats.
393
394  Args:
395    checkpoint_prefixes: a list of checkpoint paths, typically the results of
396      `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
397      sharded/non-sharded or V1/V2.
398  Returns:
399    A list of mtimes (in microseconds) of the found checkpoints.
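
  Example (a sketch; the prefixes and timestamps below are hypothetical):

  ```python
  mtimes = tf.train.get_checkpoint_mtimes(
      ["/tmp/model/ckpt-1", "/tmp/model/ckpt-2"])
  # e.g. [1546300800.0, 1546304400.0] -- seconds since the Epoch.
  ```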
400  """
401  mtimes = []
402
403  def match_maybe_append(pathname):
404    fnames = file_io.get_matching_files(pathname)
405    if fnames:
406      mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
407      return True
408    return False
409
410  for checkpoint_prefix in checkpoint_prefixes:
411    # Tries V2's metadata file first.
412    pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
413                                          saver_pb2.SaverDef.V2)
414    if match_maybe_append(pathname):
415      continue
416    # Otherwise, tries V1, where the prefix is the complete pathname.
417    match_maybe_append(checkpoint_prefix)
418
419  return mtimes
420
421
422@deprecation.deprecated(
423    date=None,
424    instructions="Use standard file APIs to delete files with this prefix.")
425@tf_export(v1=["train.remove_checkpoint"])
426def remove_checkpoint(checkpoint_prefix,
427                      checkpoint_format_version=saver_pb2.SaverDef.V2,
428                      meta_graph_suffix="meta"):
429  """Removes a checkpoint given by `checkpoint_prefix`.
430
431  Args:
432    checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result
433      of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of
434      sharded/non-sharded or V1/V2.
435    checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
436      `SaverDef.V2`.
437    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
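
  Example (a sketch; the prefix below is hypothetical):

  ```python
  tf.train.remove_checkpoint("/tmp/model/ckpt-1")
  # Deletes ckpt-1.meta, ckpt-1.index, and any ckpt-1.data-?????-of-?????
  # files that exist.
  ```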
438  """
439  _delete_file_if_exists(
440      meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
441  if checkpoint_format_version == saver_pb2.SaverDef.V2:
442    # V2 has a metadata file and some data files.
443    _delete_file_if_exists(checkpoint_prefix + ".index")
444    _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
445  else:
446    # V1, Legacy.  Exact match on the data file.
447    _delete_file_if_exists(checkpoint_prefix)
448
449
450def _delete_file_if_exists(filespec):
451  """Deletes files matching `filespec`."""
452  for pathname in file_io.get_matching_files(filespec):
453    file_io.delete_file(pathname)
454
455
456def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
457  """Returns the meta graph filename.
458
459  Args:
460    checkpoint_filename: Name of the checkpoint file.
461    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
462
463  Returns:
464    MetaGraph file name.
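
  For example, with a hypothetical sharded checkpoint filename:

  ```python
  meta_graph_filename("/tmp/model/ckpt-123456-?????-of-00005")
  # => "/tmp/model/ckpt-123456.meta"
  ```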
465  """
466  # If the checkpoint_filename is sharded, the checkpoint_filename could
467  # be of format model.ckpt-step#-?????-of-shard#. For example,
468  # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
469  basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
470  suffixed_filename = ".".join([basename, meta_graph_suffix])
471  return suffixed_filename
472
473
474# TODO(allenl): Allow tf.keras.Model instances in the constructor directly?
475@tf_export("train.CheckpointManager")
476class CheckpointManager(object):
477  """Deletes old checkpoints.
478
479  Example usage:
480  ```python
481  import tensorflow as tf
482  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
483  manager = tf.contrib.checkpoint.CheckpointManager(
484      checkpoint, directory="/tmp/model", max_to_keep=5)
485  status = checkpoint.restore(manager.latest_checkpoint)
486  while True:
487    # train
488    manager.save()
489  ```
490
491  `CheckpointManager` preserves its own state across instantiations (see the
492  `__init__` documentation for details). Only one should be active in a
493  particular directory at a time.
494  """
495
496  def __init__(self, checkpoint, directory,
497               max_to_keep, keep_checkpoint_every_n_hours=None):
498    """Configure a `CheckpointManager` for use in `directory`.
499
500    If a `CheckpointManager` was previously used in `directory`, its
501    state will be restored. This includes the list of managed checkpoints and
502    the timestamp bookkeeping necessary to support
503    `keep_checkpoint_every_n_hours`. The behavior of the new `CheckpointManager`
504    will be the same as the previous `CheckpointManager`, including cleaning up
505    existing checkpoints if appropriate.
506
507    Checkpoints are only considered for deletion just after a new checkpoint has
508    been added. At that point, `max_to_keep` checkpoints will remain in an
509    "active set". Once a checkpoint is preserved by
510    `keep_checkpoint_every_n_hours` it will not be deleted by this
511    `CheckpointManager` or any future `CheckpointManager` instantiated in
512    `directory` (regardless of the new setting of
513    `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the
514    active set may be deleted by this `CheckpointManager` or a future
515    `CheckpointManager` instantiated in `directory` (subject to its
516    `max_to_keep` and `keep_checkpoint_every_n_hours` settings).
517
518    Args:
519      checkpoint: The `tf.train.Checkpoint` instance to save and manage
520        checkpoints for.
521      directory: The path to a directory in which to write checkpoints. A
522        special file named "checkpoint" is also written to this directory (in a
523        human-readable text format) which contains the state of the
524        `CheckpointManager`.
525      max_to_keep: An integer, the number of checkpoints to keep. Unless
526        preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
527        deleted from the active set, oldest first, until only `max_to_keep`
528        checkpoints remain. If `None`, no checkpoints are deleted and everything
529        stays in the active set. Note that `max_to_keep=None` will keep all
530        checkpoint paths in memory and in the checkpoint state protocol buffer
531        on disk.
532      keep_checkpoint_every_n_hours: Upon removal from the active set, a
533        checkpoint will be preserved if it has been at least
534        `keep_checkpoint_every_n_hours` since the last preserved checkpoint. The
535        default setting of `None` does not preserve any checkpoints in this way.
536
537    Raises:
538      ValueError: If `max_to_keep` is not a positive integer.
539    """
    self._checkpoint = checkpoint
    self._save_counter_assign = None
    if max_to_keep is not None and max_to_keep <= 0:
      raise ValueError(
          ("Expected a positive integer or `None` for `max_to_keep`, "
           "got %d.")
          % (max_to_keep,))
    self._max_to_keep = max_to_keep
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    self._directory = directory
    self._checkpoint_prefix = os.path.join(directory, "ckpt")
    recovered_state = get_checkpoint_state(directory)
    current_clock = time.time()
    self._maybe_delete = collections.OrderedDict()
    if recovered_state is None:
      self._latest_checkpoint = None
      # Set the clock back slightly to avoid race conditions when quickly
      # re-creating a CheckpointManager.
      self._last_preserved_timestamp = current_clock - 1.
    else:
      self._latest_checkpoint = recovered_state.model_checkpoint_path
      self._last_preserved_timestamp = recovered_state.last_preserved_timestamp
      if current_clock < self._last_preserved_timestamp:
        # Time seems to have reversed itself. In addition to this warning, we'll
        # min() saved checkpoint timestamps with the current time to ensure that
        # old checkpoints don't get deleted accidentally.
        logging.warning(
            ("time.time() returned a value %f seconds behind the last "
             "preserved checkpoint timestamp.")
            % (self._last_preserved_timestamp - current_clock,))
        self._last_preserved_timestamp = current_clock
      all_timestamps = recovered_state.all_model_checkpoint_timestamps
      all_paths = recovered_state.all_model_checkpoint_paths
      del recovered_state  # Uses modified values from now on
      if not all_timestamps:
        all_timestamps = [self._last_preserved_timestamp] * len(all_paths)

      for filename, timestamp in zip(all_paths, all_timestamps):
        timestamp = min(timestamp, current_clock)
        if timestamp > self._last_preserved_timestamp:
          self._maybe_delete[filename] = timestamp
  @property
  def latest_checkpoint(self):
    """The prefix of the most recent checkpoint in `directory`.

    Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is
    the constructor argument to `CheckpointManager`.

    Suitable for passing to `tf.train.Checkpoint.restore` to resume training.

    Returns:
      The checkpoint prefix. If there are no checkpoints, returns `None`.
    """
    return self._latest_checkpoint

  @property
  def checkpoints(self):
    """A list of managed checkpoints.

    Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not
    show up in this list (to avoid ever-growing filename lists).

    Returns:
      A list of filenames, sorted from oldest to newest.
    """
    return list(self._maybe_delete.keys())

  def _sweep(self):
    """Deletes or preserves managed checkpoints."""
    if not self._max_to_keep:
      # Does not update self._last_preserved_timestamp, since everything is kept
      # in the active set.
      return
    while len(self._maybe_delete) > self._max_to_keep:
      filename, timestamp = self._maybe_delete.popitem(last=False)
      # Even if we're keeping this checkpoint due to
      # keep_checkpoint_every_n_hours, we won't reference it to avoid
      # infinitely-growing CheckpointState protos.
      if (self._keep_checkpoint_every_n_hours
          and (timestamp - self._keep_checkpoint_every_n_hours * 3600.
               >= self._last_preserved_timestamp)):
        self._last_preserved_timestamp = timestamp
        continue
      _delete_file_if_exists(filename + ".index")
      _delete_file_if_exists(filename + ".data-?????-of-?????")

  def _record_state(self):
    """Saves the `CheckpointManager`'s state in `directory`."""
    filenames, timestamps = zip(*self._maybe_delete.items())
    update_checkpoint_state_internal(
        self._directory,
        model_checkpoint_path=self.latest_checkpoint,
        all_model_checkpoint_paths=filenames,
        all_model_checkpoint_timestamps=timestamps,
        last_preserved_timestamp=self._last_preserved_timestamp,
        save_relative_paths=True)

  @property
  def _prefix(self):
    """A common prefix for all checkpoints saved with this manager.

    For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`,
    `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be
    numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each
    checkpoint has several associated files
    (e.g. `"/tmp/tf-model/ckpt-2.index"`).

    Returns:
      A string prefix.
    """
    return self._checkpoint_prefix

  def save(self, checkpoint_number=None):
654    """Creates a new checkpoint and manages it.
655
656    Args:
657      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
658        `Tensor`, used to number the checkpoint. If `None` (default),
659        checkpoints are numbered using `checkpoint.save_counter`. Even if
660        `checkpoint_number` is provided, `save_counter` is still incremented. A
661        user-provided `checkpoint_number` is not incremented even if it is a
662        `Variable`.
663
664    Returns:
665      The path to the new checkpoint. It is also recorded in the `checkpoints`
666      and `latest_checkpoint` properies.
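
    Example (a sketch; assumes `checkpoint` is an existing
    `tf.train.Checkpoint` and that this code runs with eager execution
    enabled):

    ```python
    manager = tf.train.CheckpointManager(
        checkpoint, directory="/tmp/model", max_to_keep=5)
    path = manager.save()  # Numbered using checkpoint.save_counter.
    path = manager.save(checkpoint_number=1000)  # Explicit numbering.
    ```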
667    """
668    # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
669    # slightly with a custom numbering option.
670    if context.executing_eagerly():
671      save_counter = self._checkpoint.save_counter
672      save_counter.assign_add(1)
673      session = None
674    else:
675      session = ops.get_default_session()
676
677      def _initializing_creator(next_creator, **kwargs):
678        """Initialize the save counter if it has been newly created."""
679        v = next_creator(**kwargs)
680        session.run(v.initializer)
681        return v
682
683      with variable_scope.variable_creator_scope(_initializing_creator):
684        save_counter = self._checkpoint.save_counter
685      if self._save_counter_assign is None:
686        self._save_counter_assign = save_counter.assign_add(1, read_value=False)
687      session.run(self._save_counter_assign)
688    if checkpoint_number is None:
689      checkpoint_number = save_counter
690    if not isinstance(checkpoint_number, compat.integral_types):
691      checkpoint_number = training_util.global_step(
692          sess=session, global_step_tensor=checkpoint_number)
693    prefix = "%s-%d" % (self._prefix, checkpoint_number)
694    save_path = self._checkpoint.write(prefix)
695    timestamp = time.time()
696    # If this is an overwritten checkpoint we were previously tracking, delete
697    # and reinsert it to make sure it goes to the end of the queue.
698    if save_path in self._maybe_delete:
699      del self._maybe_delete[save_path]
700    self._maybe_delete[save_path] = timestamp
701    self._latest_checkpoint = save_path
702    self._sweep()
703    self._record_state()
704    return save_path
705