# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=invalid-name
"""Save and restore variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os.path
import re
import time

from google.protobuf import text_format

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _evaluate(tensor):
  """Returns the numpy value of a tensor."""
  if context.executing_eagerly():
    return tensor.numpy()
  return ops.get_default_session().run(tensor)


def _GetCheckpointFilename(save_dir, latest_filename):
  """Returns a filename for storing the CheckpointState.

  Args:
    save_dir: The directory for saving and restoring checkpoints.
    latest_filename: Name of the file in 'save_dir' that is used
      to store the CheckpointState.

  Returns:
    The path of the file that contains the CheckpointState proto.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  return os.path.join(save_dir, latest_filename)


@tf_export(v1=["train.generate_checkpoint_state_proto"])
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None,
                                    all_model_checkpoint_timestamps=None,
                                    last_preserved_timestamp=None):
  """Generates a checkpoint state proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    all_model_checkpoint_timestamps: A list of floats, indicating the number of
      seconds since the Epoch when each checkpoint was generated.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).
  Returns:
    CheckpointState proto with model_checkpoint_path and
    all_model_checkpoint_paths updated to either absolute paths or
    relative paths to the current save_dir.

  Raises:
    ValueError: If `all_model_checkpoint_timestamps` was provided but its length
      does not match `all_model_checkpoint_paths`.
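
  Example (an illustrative sketch; the paths below are hypothetical):

  ```python
  state = generate_checkpoint_state_proto(
      save_dir="/tmp/train",
      model_checkpoint_path="/tmp/train/model.ckpt-200",
      all_model_checkpoint_paths=["/tmp/train/model.ckpt-100",
                                  "/tmp/train/model.ckpt-200"])
  # Since save_dir is absolute, the paths are stored unchanged:
  # state.model_checkpoint_path == "/tmp/train/model.ckpt-200"
  ```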
  """
  if all_model_checkpoint_paths is None:
    all_model_checkpoint_paths = []

  if (not all_model_checkpoint_paths or
      all_model_checkpoint_paths[-1] != model_checkpoint_path):
    logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
                 model_checkpoint_path)
    all_model_checkpoint_paths.append(model_checkpoint_path)

  if (all_model_checkpoint_timestamps
      and (len(all_model_checkpoint_timestamps)
           != len(all_model_checkpoint_paths))):
    raise ValueError(
        ("Checkpoint timestamps, if provided, must match checkpoint paths (got "
         "paths %s and timestamps %s)")
        % (all_model_checkpoint_paths, all_model_checkpoint_timestamps))

  # Relative paths need to be rewritten to be relative to the "save_dir"
  # if model_checkpoint_path already contains "save_dir".
  if not os.path.isabs(save_dir):
    if not os.path.isabs(model_checkpoint_path):
      model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
    for i, p in enumerate(all_model_checkpoint_paths):
      if not os.path.isabs(p):
        all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)

  coord_checkpoint_proto = CheckpointState(
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
      last_preserved_timestamp=last_preserved_timestamp)

  return coord_checkpoint_proto


@deprecation.deprecated(
    date=None,
    instructions=("Use `tf.train.CheckpointManager` to manage checkpoints "
                  "rather than manually editing the Checkpoint proto."))
@tf_export(v1=["train.update_checkpoint_state"])
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None,
                            all_model_checkpoint_timestamps=None,
                            last_preserved_timestamp=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.
    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
      seconds since the Epoch) indicating when the checkpoints in
      `all_model_checkpoint_paths` were created.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).
  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
  """
  update_checkpoint_state_internal(
      save_dir=save_dir,
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      latest_filename=latest_filename,
      save_relative_paths=False,
      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
      last_preserved_timestamp=last_preserved_timestamp)


@tf_export("__internal__.train.update_checkpoint_state", v1=[])
def update_checkpoint_state_internal(save_dir,
                                     model_checkpoint_path,
                                     all_model_checkpoint_paths=None,
                                     latest_filename=None,
                                     save_relative_paths=False,
                                     all_model_checkpoint_timestamps=None,
                                     last_preserved_timestamp=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.
    save_relative_paths: If `True`, will write relative paths to the checkpoint
      state file.
    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
      seconds since the Epoch) indicating when the checkpoints in
      `all_model_checkpoint_paths` were created.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).

  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
  """
  # Writes the "checkpoint" file for the coordinator for later restoration.
  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
  if save_relative_paths:
    if os.path.isabs(model_checkpoint_path):
      rel_model_checkpoint_path = os.path.relpath(
          model_checkpoint_path, save_dir)
    else:
      rel_model_checkpoint_path = model_checkpoint_path
    rel_all_model_checkpoint_paths = []
    for p in all_model_checkpoint_paths:
      if os.path.isabs(p):
        rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
      else:
        rel_all_model_checkpoint_paths.append(p)
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        rel_model_checkpoint_path,
        all_model_checkpoint_paths=rel_all_model_checkpoint_paths,
        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
        last_preserved_timestamp=last_preserved_timestamp)
  else:
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths,
        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
        last_preserved_timestamp=last_preserved_timestamp)

  if coord_checkpoint_filename == ckpt.model_checkpoint_path:
    raise RuntimeError("Save path '%s' conflicts with path used for "
                       "checkpoint state.  Please use a different save path." %
                       model_checkpoint_path)

  # Preventing potential read/write race condition by *atomically* writing to a
  # file.
  file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                      text_format.MessageToString(ckpt))


@tf_export("train.get_checkpoint_state")
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
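
  Example (illustrative; assumes checkpoints were previously written under
  '/tmp/training' so that a "checkpoint" state file exists there):

  ```python
  ckpt = tf.train.get_checkpoint_state("/tmp/training")
  if ckpt is not None:
    print(ckpt.model_checkpoint_path)       # most recent checkpoint prefix
    print(ckpt.all_model_checkpoint_paths)  # all tracked checkpoint prefixes
  ```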
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  f = None
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if file_io.file_exists(coord_checkpoint_filename):
      file_content = file_io.read_file_to_string(
          coord_checkpoint_filename)
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      if not ckpt.model_checkpoint_path:
        raise ValueError("Invalid checkpoint state loaded from "
                         + checkpoint_dir)
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(ckpt.model_checkpoint_path):
        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                  ckpt.model_checkpoint_path)
      for i, p in enumerate(ckpt.all_model_checkpoint_paths):
        if not os.path.isabs(p):
          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except errors.OpError as e:
    # It's ok if the file cannot be read
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  except text_format.ParseError as e:
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  finally:
    if f:
      f.close()
  return ckpt


def _prefix_to_checkpoint_path(prefix, format_version):
  """Returns the pathname of a checkpoint file, given the checkpoint prefix.

  For a V1 checkpoint, simply returns the prefix itself (the data file).  For
  V2, returns the pathname to the index file.

  Args:
    prefix: a string, the prefix of a checkpoint.
    format_version: the checkpoint format version that corresponds to the
      prefix.
  Returns:
    The pathname of a checkpoint file, taking into account the checkpoint
      format version.
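
  Example (illustrative; the prefix below is hypothetical):

  ```python
  _prefix_to_checkpoint_path("model.ckpt-100", saver_pb2.SaverDef.V2)
  # -> "model.ckpt-100.index"
  _prefix_to_checkpoint_path("model.ckpt-100", saver_pb2.SaverDef.V1)
  # -> "model.ckpt-100"
  ```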
  """
  if format_version == saver_pb2.SaverDef.V2:
    return prefix + ".index"  # The index file identifies a checkpoint.
  return prefix  # Just the data file.


@tf_export("train.latest_checkpoint")
def latest_checkpoint(checkpoint_dir, latest_filename=None):
  """Finds the filename of the latest saved checkpoint file.

  Gets the checkpoint state given the provided checkpoint_dir and looks for a
  corresponding TensorFlow 2 (preferred) or TensorFlow 1.x checkpoint path.
  The latest_filename argument is only applicable if you are saving checkpoints
  using `v1.train.Saver.save`.

  See the [Training Checkpoints
  Guide](https://www.tensorflow.org/guide/checkpoint) for more details and
  examples.

  Args:
    checkpoint_dir: Directory where the variables were saved.
    latest_filename: Optional name for the protocol buffer file that
      contains the list of most recent checkpoint filenames.
      See the corresponding argument to `v1.train.Saver.save`.

  Returns:
    The full path to the latest checkpoint or `None` if no checkpoint was found.
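
  Example (illustrative; assumes checkpoints were already written under
  '/tmp/training', e.g. by `tf.train.Checkpoint.save` or `v1.train.Saver.save`,
  and that `checkpoint` is a `tf.train.Checkpoint` constructed elsewhere):

  ```python
  checkpoint_path = tf.train.latest_checkpoint("/tmp/training")
  if checkpoint_path is not None:
    checkpoint.restore(checkpoint_path)
  ```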
  """
  # Pick the latest checkpoint based on checkpoint state.
  ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
  if ckpt and ckpt.model_checkpoint_path:
    # Look for either a V2 path or a V1 path, with priority for V2.
    v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V2)
    v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V1)
    if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
        v1_path):
      return ckpt.model_checkpoint_path
    else:
      logging.error("Couldn't match files for checkpoint %s",
                    ckpt.model_checkpoint_path)
  return None


def checkpoint_exists_internal(checkpoint_prefix):
  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.

  This is an internal function to check if a checkpoint exists,
  since it takes into account the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
  Returns:
    A bool, true if a checkpoint referred to by `checkpoint_prefix` exists.
  """
  pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                        saver_pb2.SaverDef.V2)
  if file_io.get_matching_files(pathname):
    return True
  elif file_io.get_matching_files(checkpoint_prefix):
    return True
  else:
    return False


@deprecation.deprecated(
    date=None,
    instructions="Use standard file APIs to check for files with this prefix.")
@tf_export(v1=["train.checkpoint_exists"])
def checkpoint_exists(checkpoint_prefix):
  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.

  This is the recommended way to check if a checkpoint exists, since it takes
  into account the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.

  Returns:
    A bool, true if a checkpoint referred to by `checkpoint_prefix` exists.
  """
  return checkpoint_exists_internal(checkpoint_prefix)


@deprecation.deprecated(
    date=None,
    instructions="Use standard file utilities to get mtimes.")
@tf_export(v1=["train.get_checkpoint_mtimes"])
def get_checkpoint_mtimes(checkpoint_prefixes):
  """Returns the mtimes (modification timestamps) of the checkpoints.

  Globs for the checkpoints pointed to by `checkpoint_prefixes`.  If the files
  exist, collect their mtime.  Both V2 and V1 checkpoints are considered, in
  that priority.

  This is the recommended way to get the mtimes, since it takes into account
  the naming difference between V1 and V2 formats.

  Note: If not all checkpoints exist, the length of the returned mtimes list
  will be smaller than the length of `checkpoint_prefixes` list, so mapping
  checkpoints to corresponding mtimes will not be possible.

  Args:
    checkpoint_prefixes: a list of checkpoint paths, typically the results of
      `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
      sharded/non-sharded or V1/V2.
  Returns:
    A list of mtimes (in seconds since the Epoch) of the found checkpoints.
  """
  mtimes = []

  def match_maybe_append(pathname):
    fnames = file_io.get_matching_files(pathname)
    if fnames:
      mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
      return True
    return False

  for checkpoint_prefix in checkpoint_prefixes:
    # Tries V2's metadata file first.
    pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                          saver_pb2.SaverDef.V2)
    if match_maybe_append(pathname):
      continue
    # Otherwise, tries V1, where the prefix is the complete pathname.
    match_maybe_append(checkpoint_prefix)

  return mtimes


@deprecation.deprecated(
    date=None,
    instructions="Use standard file APIs to delete files with this prefix.")
@tf_export(v1=["train.remove_checkpoint"])
def remove_checkpoint(checkpoint_prefix,
                      checkpoint_format_version=saver_pb2.SaverDef.V2,
                      meta_graph_suffix="meta"):
  """Removes a checkpoint given by `checkpoint_prefix`.

  Args:
    checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result
      of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of
      sharded/non-sharded or V1/V2.
    checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
      `SaverDef.V2`.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
  """
  _delete_file_if_exists(
      meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
  if checkpoint_format_version == saver_pb2.SaverDef.V2:
    # V2 has a metadata file and some data files.
    _delete_file_if_exists(checkpoint_prefix + ".index")
    _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
  else:
    # V1, Legacy.  Exact match on the data file.
    _delete_file_if_exists(checkpoint_prefix)


def _delete_file_if_exists(filespec):
  """Deletes files matching `filespec`."""
  for pathname in file_io.get_matching_files(filespec):
    try:
      file_io.delete_file(pathname)
    except errors.NotFoundError:
      logging.warning(
          "Hit NotFoundError when deleting '%s', possibly because another "
          "process/thread is also deleting/moving the same file", pathname)


def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
  """Returns the meta graph filename.

  Args:
    checkpoint_filename: Name of the checkpoint file.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.

  Returns:
    MetaGraph file name.
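
  Example (illustrative; filenames below are hypothetical):

  ```python
  meta_graph_filename("model.ckpt-123456")
  # -> "model.ckpt-123456.meta"
  meta_graph_filename("model.ckpt-123456-?????-of-00005")
  # -> "model.ckpt-123456.meta" (the shard suffix is stripped)
  ```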
  """
  # If the checkpoint_filename is sharded, the checkpoint_filename could
  # be of format model.ckpt-step#-?????-of-shard#. For example,
  # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
  basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
  suffixed_filename = ".".join([basename, meta_graph_suffix])
  return suffixed_filename


# TODO(allenl): Allow tf.keras.Model instances in the constructor directly?
@tf_export("train.CheckpointManager")
class CheckpointManager(object):
  """Manages multiple checkpoints by keeping some and deleting unneeded ones.

  Example usage:

  ```python
  import tensorflow as tf
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(
      checkpoint, directory="/tmp/model", max_to_keep=5)
  status = checkpoint.restore(manager.latest_checkpoint)
  while True:
    # train
    manager.save()
  ```

  `CheckpointManager` preserves its own state across instantiations (see the
  `__init__` documentation for details). Only one should be active in a
  particular directory at a time.
  """

  def __init__(self,
               checkpoint,
               directory,
               max_to_keep,
               keep_checkpoint_every_n_hours=None,
               checkpoint_name="ckpt",
               step_counter=None,
               checkpoint_interval=None,
               init_fn=None):
    """Configure a `CheckpointManager` for use in `directory`.

    If a `CheckpointManager` was previously used in `directory`, its
    state will be restored. This includes the list of managed checkpoints and
    the timestamp bookkeeping necessary to support
    `keep_checkpoint_every_n_hours`. The behavior of the new `CheckpointManager`
    will be the same as the previous `CheckpointManager`, including cleaning up
    existing checkpoints if appropriate.

    Checkpoints are only considered for deletion just after a new checkpoint has
    been added. At that point, `max_to_keep` checkpoints will remain in an
    "active set". Once a checkpoint is preserved by
    `keep_checkpoint_every_n_hours` it will not be deleted by this
    `CheckpointManager` or any future `CheckpointManager` instantiated in
    `directory` (regardless of the new setting of
    `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the
    active set may be deleted by this `CheckpointManager` or a future
    `CheckpointManager` instantiated in `directory` (subject to its
    `max_to_keep` and `keep_checkpoint_every_n_hours` settings).

    `CheckpointManager` can also be used to initialize the model if there are
    no checkpoints to restore from in `directory`. An example usage is:

    >>> import tempfile

    >>> tmp_dir = tempfile.mkdtemp()
    >>> checkpoint = tf.train.Checkpoint()
    >>> init_path = checkpoint.save(os.path.join(tmp_dir, 'init'))

    >>> def init_fn():
    ...   # Partially restore the checkpoint from `init_path`.
    ...   checkpoint.restore(init_path)

    >>> manager = tf.train.CheckpointManager(
    ...     checkpoint,
    ...     directory=os.path.join(tmp_dir, 'ckpt'),
    ...     max_to_keep=None,
    ...     init_fn=init_fn)
    >>> # `restore_or_initialize` will call `init_fn` if there is no existing
    >>> # checkpoint in `directory`.
    >>> manager.restore_or_initialize()

    Args:
      checkpoint: The `tf.train.Checkpoint` instance to save and manage
        checkpoints for.
      directory: The path to a directory in which to write checkpoints. A
        special file named "checkpoint" is also written to this directory (in a
        human-readable text format) which contains the state of the
        `CheckpointManager`.
      max_to_keep: An integer, the number of checkpoints to keep. Unless
        preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
        deleted from the active set, oldest first, until only `max_to_keep`
        checkpoints remain. If `None`, no checkpoints are deleted and everything
        stays in the active set. Note that `max_to_keep=None` will keep all
        checkpoint paths in memory and in the checkpoint state protocol buffer
        on disk.
      keep_checkpoint_every_n_hours: Upon removal from the active set, a
        checkpoint will be preserved if it has been at least
        `keep_checkpoint_every_n_hours` hours since the last preserved
        checkpoint. The default setting of `None` does not preserve any
        checkpoints in this way.
      checkpoint_name: Custom name for the checkpoint file.
      step_counter: A `tf.Variable` instance for checking the current step
        counter value, in case users want to save checkpoints every N steps.
      checkpoint_interval: An integer, the minimum step interval between two
        checkpoints.
      init_fn: Callable. A function to do customized initialization if no
        checkpoints are in the directory.

    Raises:
      ValueError: If `max_to_keep` is not a positive integer or `None`, or if
        `checkpoint_interval` is set without a `step_counter`.
    """
    self._checkpoint = checkpoint
    self._save_counter_assign = None
    if max_to_keep is not None and max_to_keep <= 0:
      raise ValueError(
          ("Expected a positive integer or `None` for `max_to_keep`, "
           "got %d.")
          % (max_to_keep,))
    self._max_to_keep = max_to_keep
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    self._directory = directory
    self._checkpoint_prefix = os.path.join(directory, checkpoint_name)
    self._init_fn = init_fn

    if checkpoint_interval is not None:
      if step_counter is None:
        raise ValueError("`step_counter` should be passed if "
                         "`checkpoint_interval` is not None.")
      self._last_checkpoint_step = None
      self._step_counter = step_counter
    self._checkpoint_interval = checkpoint_interval

    recovered_state = get_checkpoint_state(directory)
    current_clock = time.time()
    self._maybe_delete = collections.OrderedDict()
    if recovered_state is None:
      self._latest_checkpoint = None
      # Set the clock back slightly to avoid race conditions when quickly
      # re-creating a CheckpointManager.
      self._last_preserved_timestamp = current_clock - 1.
    else:
      self._latest_checkpoint = recovered_state.model_checkpoint_path
      self._last_preserved_timestamp = recovered_state.last_preserved_timestamp
      if current_clock < self._last_preserved_timestamp:
        # Time seems to have reversed itself. In addition to this warning, we'll
        # min() saved checkpoint timestamps with the current time to ensure that
        # old checkpoints don't get deleted accidentally.
        logging.warning(
            ("time.time() returned a value %f seconds behind the last "
             "preserved checkpoint timestamp.")
            % (self._last_preserved_timestamp - current_clock,))
        self._last_preserved_timestamp = current_clock
      all_timestamps = recovered_state.all_model_checkpoint_timestamps
      all_paths = recovered_state.all_model_checkpoint_paths
      del recovered_state  # Uses modified values from now on
      if not all_timestamps:
        all_timestamps = [self._last_preserved_timestamp] * len(all_paths)

      for filename, timestamp in zip(all_paths, all_timestamps):
        timestamp = min(timestamp, current_clock)
        if timestamp > self._last_preserved_timestamp:
          self._maybe_delete[filename] = timestamp

  @property
  def directory(self):
    return self._directory

  @property
  def checkpoint_interval(self):
    return self._checkpoint_interval

  @property
  def latest_checkpoint(self):
    """The prefix of the most recent checkpoint in `directory`.

    Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is
    the constructor argument to `CheckpointManager`.

    Suitable for passing to `tf.train.Checkpoint.restore` to resume training.

    Returns:
      The checkpoint prefix. If there are no checkpoints, returns `None`.
    """
    return self._latest_checkpoint

  @property
  def checkpoints(self):
    """A list of managed checkpoints.

    Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not
    show up in this list (to avoid ever-growing filename lists).

    Returns:
      A list of filenames, sorted from oldest to newest.
    """
    return list(self._maybe_delete.keys())

  def _sweep(self):
    """Deletes or preserves managed checkpoints."""
    if not self._max_to_keep:
      # Does not update self._last_preserved_timestamp, since everything is kept
      # in the active set.
      return
    while len(self._maybe_delete) > self._max_to_keep:
      filename, timestamp = self._maybe_delete.popitem(last=False)
      # Even if we're keeping this checkpoint due to
      # keep_checkpoint_every_n_hours, we won't reference it to avoid
      # infinitely-growing CheckpointState protos.
      if (self._keep_checkpoint_every_n_hours
          and (timestamp - self._keep_checkpoint_every_n_hours * 3600.
               >= self._last_preserved_timestamp)):
        self._last_preserved_timestamp = timestamp
        continue
      _delete_file_if_exists(filename + ".index")
      _delete_file_if_exists(filename + ".data-?????-of-?????")

  def _record_state(self):
    """Saves the `CheckpointManager`'s state in `directory`."""
    filenames, timestamps = zip(*self._maybe_delete.items())
    update_checkpoint_state_internal(
        self._directory,
        model_checkpoint_path=self.latest_checkpoint,
        all_model_checkpoint_paths=filenames,
        all_model_checkpoint_timestamps=timestamps,
        last_preserved_timestamp=self._last_preserved_timestamp,
        save_relative_paths=True)

  @property
  def _prefix(self):
    """A common prefix for all checkpoints saved with this manager.

    For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`,
    `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be
    numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each
    checkpoint has several associated files
    (e.g. `"/tmp/tf-model/ckpt-2.index"`).

    Returns:
      A string prefix.
    """
    return self._checkpoint_prefix

  @property
  def checkpoint(self):
    """Returns the `tf.train.Checkpoint` object."""
    return self._checkpoint

  def save(self, checkpoint_number=None, check_interval=True, options=None):
    """Creates a new checkpoint and manages it.

    Args:
      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
        `Tensor`, used to number the checkpoint. If `None` (default),
        checkpoints are numbered using `checkpoint.save_counter`. Even if
        `checkpoint_number` is provided, `save_counter` is still incremented. A
        user-provided `checkpoint_number` is not incremented even if it is a
        `Variable`.
      check_interval: An optional boolean. The argument is only effective when
        `checkpoint_interval` is passed into the manager. If `True`, the manager
        will only save the checkpoint if at least `checkpoint_interval` steps
        have passed since the last checkpoint. Otherwise it will always save the
        checkpoint unless a checkpoint has already been saved for the current
        step.
      options: Optional `tf.train.CheckpointOptions` object. This argument only
        works with TF2 checkpoint objects. For example, options =
        tf.train.CheckpointOptions(experimental_io_device='/job:localhost')

    Returns:
      The path to the new checkpoint. It is also recorded in the `checkpoints`
      and `latest_checkpoint` properties. `None` if no checkpoint is saved.
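
    Example (illustrative sketch; `manager` below is a `CheckpointManager`
    constructed elsewhere, and the explicit number is hypothetical):

    ```python
    path = manager.save()                      # numbered by `save_counter`
    path = manager.save(checkpoint_number=42)  # explicit numbering
    if path is None:
      pass  # skipped due to `checkpoint_interval` bookkeeping
    ```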
    """
    if self._checkpoint_interval is not None:
      current_step = _evaluate(self._step_counter)
      if self._last_checkpoint_step is not None:
        if current_step == self._last_checkpoint_step:
          return None
        if check_interval and current_step < (
            self._last_checkpoint_step + self._checkpoint_interval):
          return None
      self._last_checkpoint_step = current_step

    # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
    # slightly with a custom numbering option.
    if context.executing_eagerly():
      save_counter = self._checkpoint.save_counter
      save_counter.assign_add(1)
      session = None
    else:
      session = ops.get_default_session()

      def _initializing_creator(next_creator, **kwargs):
        """Initialize the save counter if it has been newly created."""
        v = next_creator(**kwargs)
        session.run(v.initializer)
        return v

      with variable_scope.variable_creator_scope(_initializing_creator):
        save_counter = self._checkpoint.save_counter
      if self._save_counter_assign is None:
        self._save_counter_assign = save_counter.assign_add(1, read_value=False)
      session.run(self._save_counter_assign)
    if checkpoint_number is None:
      checkpoint_number = save_counter
    if not isinstance(checkpoint_number, compat.integral_types):
      checkpoint_number = training_util.global_step(
          sess=session, global_step_tensor=checkpoint_number)
    prefix = "%s-%d" % (self._prefix, checkpoint_number)
    if options is None:
      save_path = self._checkpoint.write(prefix)
    else:
      save_path = self._checkpoint.write(prefix, options=options)
    timestamp = time.time()
    # If this is an overwritten checkpoint we were previously tracking, delete
    # and reinsert it to make sure it goes to the end of the queue.
    if save_path in self._maybe_delete:
      del self._maybe_delete[save_path]
    self._maybe_delete[save_path] = timestamp
    self._latest_checkpoint = save_path
    # Before deleting anything we update the Checkpoint proto with the new
    # checkpoint. We'll go back and correct it after cleaning up old files, but
    # a preemption while deleting will be more likely to see the new checkpoint
    # this way.
    self._record_state()
    self._sweep()
    # Write out the Checkpoint proto a second time, now without the deleted
    # checkpoints.
    self._record_state()
    return save_path

  def restore_or_initialize(self):
    """Restore items in `checkpoint` from the latest checkpoint file.

    This method will first try to restore from the most recent checkpoint in
    `directory`. If no checkpoints exist in `directory`, and `init_fn` is
    specified, this method will call `init_fn` to do customized
    initialization. This can be used to support initialization from pretrained
    models.

    Note that unlike `tf.train.Checkpoint.restore()`, this method does not
    return a load status object that users can run assertions on
    (e.g. `assert_consumed()`). To run assertions, users should call the
    `tf.train.Checkpoint.restore()` method directly.

    Returns:
      The restored checkpoint path if the latest checkpoint is found and
      restored. Otherwise None.
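
    Example (illustrative sketch; `manager` below was constructed with an
    `init_fn`, as in the constructor docstring above):

    ```python
    restored_path = manager.restore_or_initialize()
    if restored_path is None:
      print("No checkpoint found; model was initialized via init_fn.")
    ```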
    """
    if self._latest_checkpoint is not None:
      self._checkpoint.restore(self._latest_checkpoint)
      if self._checkpoint_interval is not None:
        self._last_checkpoint_step = _evaluate(self._step_counter)
      return self._latest_checkpoint

    if self._init_fn is not None:
      self._init_fn()
    return None