• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16# pylint: disable=invalid-name
17"""Save and restore variables."""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import os.path
24import re
25import time
26
27from google.protobuf import text_format
28
29from tensorflow.core.protobuf import saver_pb2
30from tensorflow.python.eager import context
31from tensorflow.python.framework import errors
32from tensorflow.python.framework import ops
33from tensorflow.python.lib.io import file_io
34from tensorflow.python.ops import variable_scope
35from tensorflow.python.platform import tf_logging as logging
36from tensorflow.python.training import training_util
37from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
38from tensorflow.python.util import compat
39from tensorflow.python.util import deprecation
40from tensorflow.python.util.tf_export import tf_export
41
42
def _evaluate(tensor):
  """Fetches the numpy value backing `tensor`.

  Works both eagerly (via `Tensor.numpy()`) and in graph mode (by running the
  tensor in the default session).
  """
  if not context.executing_eagerly():
    return ops.get_default_session().run(tensor)
  return tensor.numpy()
48
49
def _GetCheckpointFilename(save_dir, latest_filename):
  """Builds the path of the file that stores the CheckpointState.

  Args:
    save_dir: The directory for saving and restoring checkpoints.
    latest_filename: Name of the file in 'save_dir' that is used
      to store the CheckpointState.  Falls back to "checkpoint" when `None`.

  Returns:
    The path of the file that contains the CheckpointState proto.
  """
  basename = "checkpoint" if latest_filename is None else latest_filename
  return os.path.join(save_dir, basename)
64
65
@tf_export(v1=["train.generate_checkpoint_state_proto"])
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None,
                                    all_model_checkpoint_timestamps=None,
                                    last_preserved_timestamp=None):
  """Generates a checkpoint state proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.  The input sequence itself
      is never modified.
    all_model_checkpoint_timestamps: A list of floats, indicating the number of
      seconds since the Epoch when each checkpoint was generated.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).
  Returns:
    CheckpointState proto with model_checkpoint_path and
    all_model_checkpoint_paths updated to either absolute paths or
    relative paths to the current save_dir.

  Raises:
    ValueError: If `all_model_checkpoint_timestamps` was provided but its length
      does not match `all_model_checkpoint_paths`.
  """
  # Work on a copy so the caller's sequence is never mutated in place.  This
  # also accepts tuples (e.g. produced by `zip(*...)`), which the previous
  # in-place `append` would have rejected.
  if all_model_checkpoint_paths is None:
    all_model_checkpoint_paths = []
  else:
    all_model_checkpoint_paths = list(all_model_checkpoint_paths)

  if (not all_model_checkpoint_paths or
      all_model_checkpoint_paths[-1] != model_checkpoint_path):
    logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
                 model_checkpoint_path)
    all_model_checkpoint_paths.append(model_checkpoint_path)

  if (all_model_checkpoint_timestamps
      and (len(all_model_checkpoint_timestamps)
           != len(all_model_checkpoint_paths))):
    raise ValueError(
        ("Checkpoint timestamps, if provided, must match checkpoint paths (got "
         "paths %s and timestamps %s)")
        % (all_model_checkpoint_paths, all_model_checkpoint_timestamps))

  # Relative paths need to be rewritten to be relative to the "save_dir"
  # if model_checkpoint_path already contains "save_dir".
  if not os.path.isabs(save_dir):
    if not os.path.isabs(model_checkpoint_path):
      model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
    for i, p in enumerate(all_model_checkpoint_paths):
      if not os.path.isabs(p):
        all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)

  coord_checkpoint_proto = CheckpointState(
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
      last_preserved_timestamp=last_preserved_timestamp)

  return coord_checkpoint_proto
129
130
@deprecation.deprecated(
    date=None,
    instructions=("Use `tf.train.CheckpointManager` to manage checkpoints "
                  "rather than manually editing the Checkpoint proto."))
@tf_export(v1=["train.update_checkpoint_state"])
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None,
                            all_model_checkpoint_timestamps=None,
                            last_preserved_timestamp=None):
  """Updates the content of the 'checkpoint' file.

  Rewrites the CheckpointState proto stored in `save_dir`, recording
  `model_checkpoint_path` (and optionally the full history of checkpoint paths
  and their timestamps).  Paths are written as given, i.e. not rewritten
  relative to `save_dir`.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.
    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
      seconds since the Epoch) indicating when the checkpoints in
      `all_model_checkpoint_paths` were created.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).
  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointSate.
  """
  # Delegate to the internal helper; this public variant never writes
  # relative paths.
  update_checkpoint_state_internal(
      save_dir,
      model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      latest_filename=latest_filename,
      save_relative_paths=False,
      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
      last_preserved_timestamp=last_preserved_timestamp)
175
176
def update_checkpoint_state_internal(save_dir,
                                     model_checkpoint_path,
                                     all_model_checkpoint_paths=None,
                                     latest_filename=None,
                                     save_relative_paths=False,
                                     all_model_checkpoint_timestamps=None,
                                     last_preserved_timestamp=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.
    save_relative_paths: If `True`, will write relative paths to the checkpoint
      state file.
    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
      seconds since the Epoch) indicating when the checkpoints in
      `all_model_checkpoint_paths` were created.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).

  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointSate.
  """
  # Writes the "checkpoint" file for the coordinator for later restoration.
  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
  # Normalize the default `None` to an empty list.  Previously the
  # `save_relative_paths` branch iterated over `None` and raised a TypeError
  # when callers relied on the default.
  if all_model_checkpoint_paths is None:
    all_model_checkpoint_paths = []
  if save_relative_paths:
    # Rewrite absolute paths relative to `save_dir` before serialization.
    if os.path.isabs(model_checkpoint_path):
      rel_model_checkpoint_path = os.path.relpath(
          model_checkpoint_path, save_dir)
    else:
      rel_model_checkpoint_path = model_checkpoint_path
    rel_all_model_checkpoint_paths = []
    for p in all_model_checkpoint_paths:
      if os.path.isabs(p):
        rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
      else:
        rel_all_model_checkpoint_paths.append(p)
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        rel_model_checkpoint_path,
        all_model_checkpoint_paths=rel_all_model_checkpoint_paths,
        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
        last_preserved_timestamp=last_preserved_timestamp)
  else:
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths,
        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
        last_preserved_timestamp=last_preserved_timestamp)

  if coord_checkpoint_filename == ckpt.model_checkpoint_path:
    raise RuntimeError("Save path '%s' conflicts with path used for "
                       "checkpoint state.  Please use a different save path." %
                       model_checkpoint_path)

  # Preventing potential read/write race condition by *atomically* writing to a
  # file.
  file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                      text_format.MessageToString(ckpt))
249
250
@tf_export("train.get_checkpoint_state")
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  # Note: removed a leftover `f = None` / `finally: f.close()` pair; no file
  # handle is ever opened here (file_io reads the whole file in one call).
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if file_io.file_exists(coord_checkpoint_filename):
      file_content = file_io.read_file_to_string(
          coord_checkpoint_filename)
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      if not ckpt.model_checkpoint_path:
        raise ValueError("Invalid checkpoint state loaded from "
                         + checkpoint_dir)
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(ckpt.model_checkpoint_path):
        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                  ckpt.model_checkpoint_path)
      for i, p in enumerate(ckpt.all_model_checkpoint_paths):
        if not os.path.isabs(p):
          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except (errors.OpError, text_format.ParseError) as e:
    # It's ok if the file cannot be read or parsed: treat it as missing.
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  return ckpt
306
307
def _prefix_to_checkpoint_path(prefix, format_version):
  """Returns the pathname of a checkpoint file, given the checkpoint prefix.

  For V1 checkpoint, simply returns the prefix itself (the data file).  For V2,
  returns the pathname to the index file.

  Args:
    prefix: a string, the prefix of a checkpoint.
    format_version: the checkpoint format version that corresponds to the
      prefix.
  Returns:
    The pathname of a checkpoint file, taking into account the checkpoint
      format version.
  """
  if format_version != saver_pb2.SaverDef.V2:
    return prefix  # V1 / legacy: just the data file.
  return prefix + ".index"  # V2: the index file identifies a checkpoint.
325
326
@tf_export("train.latest_checkpoint")
def latest_checkpoint(checkpoint_dir, latest_filename=None):
  """Finds the filename of latest saved checkpoint file.

  Gets the checkpoint state given the provided checkpoint_dir and looks for a
  corresponding TensorFlow 2 (preferred) or TensorFlow 1.x checkpoint path.
  The latest_filename argument is only applicable if you are saving checkpoint
  using `v1.train.Saver.save`


  See the [Training Checkpoints
  Guide](https://www.tensorflow.org/guide/checkpoint) for more details and
  examples.`

  Args:
    checkpoint_dir: Directory where the variables were saved.
    latest_filename: Optional name for the protocol buffer file that
      contains the list of most recent checkpoint filenames.
      See the corresponding argument to `v1.train.Saver.save`.

  Returns:
    The full path to the latest checkpoint or `None` if no checkpoint was found.
  """
  # Pick the latest checkpoint based on checkpoint state.
  ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
  if not ckpt or not ckpt.model_checkpoint_path:
    return None
  prefix = ckpt.model_checkpoint_path
  # Prefer the V2 format (index file) over the V1 data file.
  for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
    pathname = _prefix_to_checkpoint_path(prefix, version)
    if file_io.get_matching_files(pathname):
      return prefix
  logging.error("Couldn't match files for checkpoint %s", prefix)
  return None
365
366
def checkpoint_exists_internal(checkpoint_prefix):
  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.

  This is an internal function to check if a checkpoint exists,
  since it takes into account the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
  Returns:
    A bool, true if a checkpoint referred to by `checkpoint_prefix` exists.
  """
  v2_pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                           saver_pb2.SaverDef.V2)
  # Try the V2 index file first; fall back to the V1 data file.  `or`
  # short-circuits, so V1 is only globbed when V2 finds nothing.
  return bool(file_io.get_matching_files(v2_pathname) or
              file_io.get_matching_files(checkpoint_prefix))
389
390
@deprecation.deprecated(
    date=None,
    instructions="Use standard file APIs to check for files with this prefix.")
@tf_export(v1=["train.checkpoint_exists"])
def checkpoint_exists(checkpoint_prefix):
  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.

  Deprecated public wrapper around `checkpoint_exists_internal`, which handles
  the naming difference between V1 and V2 checkpoint formats.

  Args:
    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.

  Returns:
    A bool, true if a checkpoint referred to by `checkpoint_prefix` exists.
  """
  return checkpoint_exists_internal(checkpoint_prefix)
411
412
@deprecation.deprecated(
    date=None,
    instructions="Use standard file utilities to get mtimes.")
@tf_export(v1=["train.get_checkpoint_mtimes"])
def get_checkpoint_mtimes(checkpoint_prefixes):
  """Returns the mtimes (modification timestamps) of the checkpoints.

  Globs for the checkpoints pointed to by `checkpoint_prefixes`.  If the files
  exist, collect their mtime.  Both V2 and V1 checkpoints are considered, in
  that priority.

  This is the recommended way to get the mtimes, since it takes into account
  the naming difference between V1 and V2 formats.

  Note: If not all checkpoints exist, the length of the returned mtimes list
  will be smaller than the length of `checkpoint_prefixes` list, so mapping
  checkpoints to corresponding mtimes will not be possible.

  Args:
    checkpoint_prefixes: a list of checkpoint paths, typically the results of
      `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
      sharded/non-sharded or V1/V2.
  Returns:
    A list of mtimes (in seconds since the Epoch) of the found checkpoints.
  """
  # Doc fix: the previous docstring said "microseconds", but the code divides
  # `mtime_nsec` by 1e9, which yields seconds.
  mtimes = []

  def match_maybe_append(pathname):
    """Appends the mtime of the first file matching `pathname`, if any."""
    fnames = file_io.get_matching_files(pathname)
    if fnames:
      # Convert nanoseconds to seconds.
      mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
      return True
    return False

  for checkpoint_prefix in checkpoint_prefixes:
    # Tries V2's metadata file first.
    pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                          saver_pb2.SaverDef.V2)
    if match_maybe_append(pathname):
      continue
    # Otherwise, tries V1, where the prefix is the complete pathname.
    match_maybe_append(checkpoint_prefix)

  return mtimes
457
458
@deprecation.deprecated(
    date=None,
    instructions="Use standard file APIs to delete files with this prefix.")
@tf_export(v1=["train.remove_checkpoint"])
def remove_checkpoint(checkpoint_prefix,
                      checkpoint_format_version=saver_pb2.SaverDef.V2,
                      meta_graph_suffix="meta"):
  """Removes a checkpoint given by `checkpoint_prefix`.

  Args:
    checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result
      of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of
      sharded/non-sharded or V1/V2.
    checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
      `SaverDef.V2`.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
  """
  # The meta graph file exists for both formats.
  _delete_file_if_exists(
      meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
  if checkpoint_format_version != saver_pb2.SaverDef.V2:
    # V1, Legacy.  Exact match on the data file.
    _delete_file_if_exists(checkpoint_prefix)
    return
  # V2 has a metadata file and some data files.
  for suffix in (".index", ".data-?????-of-?????"):
    _delete_file_if_exists(checkpoint_prefix + suffix)
485
486
def _delete_file_if_exists(filespec):
  """Deletes every file matching the glob `filespec`, tolerating races."""
  matches = file_io.get_matching_files(filespec)
  for pathname in matches:
    try:
      file_io.delete_file(pathname)
    except errors.NotFoundError:
      # The file may have been removed between globbing and deletion; this is
      # expected when multiple processes manage the same directory.
      logging.warning(
          "Hit NotFoundError when deleting '%s', possibly because another "
          "process/thread is also deleting/moving the same file", pathname)
496
497
def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
  """Returns the meta graph filename for a checkpoint.

  Sharded checkpoint filenames may look like model.ckpt-123456-?????-of-00005
  or model.ckpt-123456-00001-of-00002; the trailing shard part is stripped
  before appending the meta-graph suffix.

  Args:
    checkpoint_filename: Name of the checkpoint file.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.

  Returns:
    MetaGraph file name.
  """
  stripped = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
  return "%s.%s" % (stripped, meta_graph_suffix)
514
515
516# TODO(allenl): Allow tf.keras.Model instances in the constructor directly?
517@tf_export("train.CheckpointManager")
518class CheckpointManager(object):
519  """Manages multiple checkpoints by keeping some and deleting unneeded ones.
520
521  Example usage:
522
523  ```python
524  import tensorflow as tf
525  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
526  manager = tf.train.CheckpointManager(
527      checkpoint, directory="/tmp/model", max_to_keep=5)
528  status = checkpoint.restore(manager.latest_checkpoint)
529  while True:
530    # train
531    manager.save()
532  ```
533
534  `CheckpointManager` preserves its own state across instantiations (see the
535  `__init__` documentation for details). Only one should be active in a
536  particular directory at a time.
537  """
538
  def __init__(self,
               checkpoint,
               directory,
               max_to_keep,
               keep_checkpoint_every_n_hours=None,
               checkpoint_name="ckpt",
               step_counter=None,
               checkpoint_interval=None,
               init_fn=None):
    """Configure a `CheckpointManager` for use in `directory`.

    If a `CheckpointManager` was previously used in `directory`, its
    state will be restored. This includes the list of managed checkpoints and
    the timestamp bookkeeping necessary to support
    `keep_checkpoint_every_n_hours`. The behavior of the new `CheckpointManager`
    will be the same as the previous `CheckpointManager`, including cleaning up
    existing checkpoints if appropriate.

    Checkpoints are only considered for deletion just after a new checkpoint has
    been added. At that point, `max_to_keep` checkpoints will remain in an
    "active set". Once a checkpoint is preserved by
    `keep_checkpoint_every_n_hours` it will not be deleted by this
    `CheckpointManager` or any future `CheckpointManager` instantiated in
    `directory` (regardless of the new setting of
    `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the
    active set may be deleted by this `CheckpointManager` or a future
    `CheckpointManager` instantiated in `directory` (subject to its
    `max_to_keep` and `keep_checkpoint_every_n_hours` settings).

    `CheckpointManager` can be also used for initializing the model if
    there is no checkpoints for restoring in `directory`. An example usage is:

    >>> import tempfile

    >>> tmp_dir = tempfile.mkdtemp()
    >>> checkpoint = tf.train.Checkpoint()
    >>> init_path = checkpoint.save(os.path.join(tmp_dir, 'init'))

    >>> def init_fn():
    ...   # Partially restore the checkpoint from `init_path`.
    ...   checkpoint.restore(init_path)

    >>> manager = tf.train.CheckpointManager(
    ...     checkpoint,
    ...     directory=os.path.join(tmp_dir, 'ckpt'),
    ...     max_to_keep=None,
    ...     init_fn=init_fn)
    >>> # `restore_or_initialize` will call `init_fn` if there is no existing
    >>> # checkpoint in `directory`.
    >>> manager.restore_or_initialize()

    Args:
      checkpoint: The `tf.train.Checkpoint` instance to save and manage
        checkpoints for.
      directory: The path to a directory in which to write checkpoints. A
        special file named "checkpoint" is also written to this directory (in a
        human-readable text format) which contains the state of the
        `CheckpointManager`.
      max_to_keep: An integer, the number of checkpoints to keep. Unless
        preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
        deleted from the active set, oldest first, until only `max_to_keep`
        checkpoints remain. If `None`, no checkpoints are deleted and everything
        stays in the active set. Note that `max_to_keep=None` will keep all
        checkpoint paths in memory and in the checkpoint state protocol buffer
        on disk.
      keep_checkpoint_every_n_hours: Upon removal from the active set, a
        checkpoint will be preserved if it has been at least
        `keep_checkpoint_every_n_hours` since the last preserved checkpoint. The
        default setting of `None` does not preserve any checkpoints in this way.
      checkpoint_name: Custom name for the checkpoint file.
      step_counter: A `tf.Variable` instance for checking the current step
        counter value, in case users want to save checkpoints every N steps.
      checkpoint_interval: An integer, indicates the minimum step interval
        between two checkpoints.
      init_fn: Callable. A function to do customized intialization if no
        checkpoints are in the directory.

    Raises:
      ValueError: If `max_to_keep` is not a positive integer.
    """
    self._checkpoint = checkpoint
    self._save_counter_assign = None
    # `max_to_keep` may be `None` (keep everything), but never zero/negative.
    if max_to_keep is not None and max_to_keep <= 0:
      raise ValueError(
          ("Expected a positive integer or `None` for `max_to_keep`, "
           "got %d.")
          % (max_to_keep,))
    self._max_to_keep = max_to_keep
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    self._directory = directory
    # All checkpoints written by this manager share this path prefix.
    self._checkpoint_prefix = os.path.join(directory, checkpoint_name)
    self._init_fn = init_fn

    # Step-based save gating needs a counter variable to compare against.
    if checkpoint_interval is not None:
      if step_counter is None:
        raise ValueError("`step_counter` should be passed if "
                         "`checkpoint_interval` is not None.")
      self._last_checkpoint_step = None
      self._step_counter = step_counter
    self._checkpoint_interval = checkpoint_interval

    # Recover bookkeeping left by a previous CheckpointManager (if any) from
    # the "checkpoint" state file in `directory`.
    recovered_state = get_checkpoint_state(directory)
    current_clock = time.time()
    # Maps checkpoint prefix -> timestamp, oldest first; candidates for
    # deletion by _sweep().
    self._maybe_delete = collections.OrderedDict()
    if recovered_state is None:
      self._latest_checkpoint = None
      # Set the clock back slightly to avoid race conditions when quickly
      # re-creating a CheckpointManager.
      self._last_preserved_timestamp = current_clock - 1.
    else:
      self._latest_checkpoint = recovered_state.model_checkpoint_path
      self._last_preserved_timestamp = recovered_state.last_preserved_timestamp
      if current_clock < self._last_preserved_timestamp:
        # Time seems to have reversed itself. In addition to this warning, we'll
        # min() saved checkpoint timestamps with the current time to ensure that
        # old checkpoints don't get deleted accidentally.
        logging.warning(
            ("time.time() returned a value %f seconds behind the last "
             "preserved checkpoint timestamp.")
            % (self._last_preserved_timestamp - current_clock,))
        self._last_preserved_timestamp = current_clock
      all_timestamps = recovered_state.all_model_checkpoint_timestamps
      all_paths = recovered_state.all_model_checkpoint_paths
      del recovered_state  # Uses modified values from now on
      # Older state files may lack timestamps; assume everything was written
      # at the last-preserved time so nothing is preserved spuriously.
      if not all_timestamps:
        all_timestamps = [self._last_preserved_timestamp] * len(all_paths)

      for filename, timestamp in zip(all_paths, all_timestamps):
        # Clamp to the current clock so reversed clocks can't inflate ages.
        timestamp = min(timestamp, current_clock)
        if timestamp > self._last_preserved_timestamp:
          self._maybe_delete[filename] = timestamp
670
  @property
  def directory(self):
    """The directory checkpoints are written to (constructor argument)."""
    return self._directory
674
  @property
  def checkpoint_interval(self):
    """Minimum step interval between checkpoints; `None` if not configured."""
    return self._checkpoint_interval
678
679  @property
680  def latest_checkpoint(self):
681    """The prefix of the most recent checkpoint in `directory`.
682
683    Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is
684    the constructor argument to `CheckpointManager`.
685
686    Suitable for passing to `tf.train.Checkpoint.restore` to resume training.
687
688    Returns:
689      The checkpoint prefix. If there are no checkpoints, returns `None`.
690    """
691    return self._latest_checkpoint
692
693  @property
694  def checkpoints(self):
695    """A list of managed checkpoints.
696
697    Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not
698    show up in this list (to avoid ever-growing filename lists).
699
700    Returns:
701      A list of filenames, sorted from oldest to newest.
702    """
703    return list(self._maybe_delete.keys())
704
705  def _sweep(self):
706    """Deletes or preserves managed checkpoints."""
707    if not self._max_to_keep:
708      # Does not update self._last_preserved_timestamp, since everything is kept
709      # in the active set.
710      return
711    while len(self._maybe_delete) > self._max_to_keep:
712      filename, timestamp = self._maybe_delete.popitem(last=False)
713      # Even if we're keeping this checkpoint due to
714      # keep_checkpoint_every_n_hours, we won't reference it to avoid
715      # infinitely-growing CheckpointState protos.
716      if (self._keep_checkpoint_every_n_hours
717          and (timestamp - self._keep_checkpoint_every_n_hours * 3600.
718               >= self._last_preserved_timestamp)):
719        self._last_preserved_timestamp = timestamp
720        continue
721      _delete_file_if_exists(filename + ".index")
722      _delete_file_if_exists(filename + ".data-?????-of-?????")
723
724  def _record_state(self):
725    """Saves the `CheckpointManager`'s state in `directory`."""
726    filenames, timestamps = zip(*self._maybe_delete.items())
727    update_checkpoint_state_internal(
728        self._directory,
729        model_checkpoint_path=self.latest_checkpoint,
730        all_model_checkpoint_paths=filenames,
731        all_model_checkpoint_timestamps=timestamps,
732        last_preserved_timestamp=self._last_preserved_timestamp,
733        save_relative_paths=True)
734
735  @property
736  def _prefix(self):
737    """A common prefix for all checkpoints saved with this manager.
738
739    For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`,
740    `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be
741    numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each
742    checkpoint has several associated files
743    (e.g. `"/tmp/tf-model/ckpt-2.index"`).
744
745    Returns:
746      A string prefix.
747    """
748    return self._checkpoint_prefix
749
750  @property
751  def checkpoint(self):
752    """Returns the `tf.train.Checkpoint` object."""
753    return self._checkpoint
754
  def save(self, checkpoint_number=None, check_interval=True, options=None):
    """Creates a new checkpoint and manages it.

    Args:
      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
        `Tensor`, used to number the checkpoint. If `None` (default),
        checkpoints are numbered using `checkpoint.save_counter`. Even if
        `checkpoint_number` is provided, `save_counter` is still incremented. A
        user-provided `checkpoint_number` is not incremented even if it is a
        `Variable`.
      check_interval: An optional boolean. The argument is only effective when
        `checkpoint_interval` is passed into the manager. If `True`, the manager
        will only save the checkpoint if the interval between checkpoints is
        larger than `checkpoint_interval`. Otherwise it will always save the
        checkpoint unless a checkpoint has already been saved for the current
        step.
      options: Optional `tf.train.CheckpointOptions` object. This argument only
        works with TF2 checkpoint objects. For example, options =
        tf.train.CheckpointOptions(experimental_io_device='/job:localhost')

    Returns:
      The path to the new checkpoint. It is also recorded in the `checkpoints`
      and `latest_checkpoint` properties. `None` if no checkpoint is saved.
    """
    # Interval gating: skip the save entirely if we already saved at this step,
    # or (when check_interval is set) if fewer than checkpoint_interval steps
    # have elapsed since the last recorded save.
    if self._checkpoint_interval is not None:
      current_step = _evaluate(self._step_counter)
      if self._last_checkpoint_step is not None:
        if current_step == self._last_checkpoint_step:
          return None
        if check_interval and current_step < (
            self._last_checkpoint_step + self._checkpoint_interval):
          return None
      self._last_checkpoint_step = current_step

    # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
    # slightly with a custom numbering option.
    if context.executing_eagerly():
      save_counter = self._checkpoint.save_counter
      save_counter.assign_add(1)
      session = None
    else:
      session = ops.get_default_session()

      def _initializing_creator(next_creator, **kwargs):
        """Initialize the save counter if it has been newly created."""
        v = next_creator(**kwargs)
        session.run(v.initializer)
        return v

      with variable_scope.variable_creator_scope(_initializing_creator):
        save_counter = self._checkpoint.save_counter
      if self._save_counter_assign is None:
        # Cache the assign op so repeated save() calls in graph mode do not
        # keep adding nodes to the graph.
        self._save_counter_assign = save_counter.assign_add(1, read_value=False)
      session.run(self._save_counter_assign)
    if checkpoint_number is None:
      checkpoint_number = save_counter
    if not isinstance(checkpoint_number, compat.integral_types):
      # Resolve a Variable/Tensor checkpoint number to a Python integer.
      checkpoint_number = training_util.global_step(
          sess=session, global_step_tensor=checkpoint_number)
    prefix = "%s-%d" % (self._prefix, checkpoint_number)
    if options is None:
      save_path = self._checkpoint.write(prefix)
    else:
      save_path = self._checkpoint.write(prefix, options=options)
    timestamp = time.time()
    # If this is an overwritten checkpoint we were previously tracking, delete
    # and reinsert it to make sure it goes to the end of the queue.
    if save_path in self._maybe_delete:
      del self._maybe_delete[save_path]
    self._maybe_delete[save_path] = timestamp
    self._latest_checkpoint = save_path
    # Before deleting anything we update the Checkpoint proto with the new
    # checkpoint. We'll go back and correct it after cleaning up old files, but
    # a preemption while deleting will be more likely to see the new checkpoint
    # this way.
    self._record_state()
    self._sweep()
    # Write out the Checkpoint proto a second time, now without the deleted
    # checkpoints.
    self._record_state()
    return save_path
836
837  def restore_or_initialize(self):
838    """Restore items in `checkpoint` from the latest checkpoint file.
839
840    This method will first try to restore from the most recent checkpoint in
841    `directory`. If no checkpoints exist in `directory`, and `init_fn` is
842    specified, this method will call `init_fn` to do customized
843    initialization. This can be used to support initialization from pretrained
844    models.
845
846    Note that unlike `tf.train.Checkpoint.restore()`, this method doesn't return
847    a load status object that users can run assertions on
848    (e.g. assert_consumed()). Thus to run assertions, users should directly use
849    `tf.train.Checkpoint.restore()` method.
850
851    Returns:
852      The restored checkpoint path if the lastest checkpoint is found and
853      restored. Otherwise None.
854    """
855    if self._latest_checkpoint is not None:
856      self._checkpoint.restore(self._latest_checkpoint)
857      if self._checkpoint_interval is not None:
858        self._last_checkpoint_step = _evaluate(self._step_counter)
859      return self._latest_checkpoint
860
861    if self._init_fn is not None:
862      self._init_fn()
863    return None
864