1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16# pylint: disable=invalid-name 17"""Save and restore variables.""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23import os.path 24import re 25import time 26 27from google.protobuf import text_format 28 29from tensorflow.core.protobuf import saver_pb2 30from tensorflow.python.eager import context 31from tensorflow.python.framework import errors 32from tensorflow.python.framework import ops 33from tensorflow.python.lib.io import file_io 34from tensorflow.python.ops import variable_scope 35from tensorflow.python.platform import tf_logging as logging 36from tensorflow.python.training import training_util 37from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState 38from tensorflow.python.util import compat 39from tensorflow.python.util import deprecation 40from tensorflow.python.util.tf_export import tf_export 41 42 43def _evaluate(tensor): 44 """Returns the numpy value of a tensor.""" 45 if context.executing_eagerly(): 46 return tensor.numpy() 47 return ops.get_default_session().run(tensor) 48 49 50def _GetCheckpointFilename(save_dir, latest_filename): 51 """Returns a filename for storing the CheckpointState. 52 53 Args: 54 save_dir: The directory for saving and restoring checkpoints. 55 latest_filename: Name of the file in 'save_dir' that is used 56 to store the CheckpointState. 57 58 Returns: 59 The path of the file that contains the CheckpointState proto. 60 """ 61 if latest_filename is None: 62 latest_filename = "checkpoint" 63 return os.path.join(save_dir, latest_filename) 64 65 66@tf_export(v1=["train.generate_checkpoint_state_proto"]) 67def generate_checkpoint_state_proto(save_dir, 68 model_checkpoint_path, 69 all_model_checkpoint_paths=None, 70 all_model_checkpoint_timestamps=None, 71 last_preserved_timestamp=None): 72 """Generates a checkpoint state proto. 73 74 Args: 75 save_dir: Directory where the model was saved. 76 model_checkpoint_path: The checkpoint file. 77 all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted 78 checkpoints, sorted from oldest to newest. If this is a non-empty list, 79 the last element must be equal to model_checkpoint_path. These paths 80 are also saved in the CheckpointState proto. 81 all_model_checkpoint_timestamps: A list of floats, indicating the number of 82 seconds since the Epoch when each checkpoint was generated. 83 last_preserved_timestamp: A float, indicating the number of seconds since 84 the Epoch when the last preserved checkpoint was written, e.g. due to a 85 `keep_checkpoint_every_n_hours` parameter (see 86 `tf.train.CheckpointManager` for an implementation). 87 Returns: 88 CheckpointState proto with model_checkpoint_path and 89 all_model_checkpoint_paths updated to either absolute paths or 90 relative paths to the current save_dir. 91 92 Raises: 93 ValueError: If `all_model_checkpoint_timestamps` was provided but its length 94 does not match `all_model_checkpoint_paths`. 95 """ 96 if all_model_checkpoint_paths is None: 97 all_model_checkpoint_paths = [] 98 99 if (not all_model_checkpoint_paths or 100 all_model_checkpoint_paths[-1] != model_checkpoint_path): 101 logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.", 102 model_checkpoint_path) 103 all_model_checkpoint_paths.append(model_checkpoint_path) 104 105 if (all_model_checkpoint_timestamps 106 and (len(all_model_checkpoint_timestamps) 107 != len(all_model_checkpoint_paths))): 108 raise ValueError( 109 ("Checkpoint timestamps, if provided, must match checkpoint paths (got " 110 "paths %s and timestamps %s)") 111 % (all_model_checkpoint_paths, all_model_checkpoint_timestamps)) 112 113 # Relative paths need to be rewritten to be relative to the "save_dir" 114 # if model_checkpoint_path already contains "save_dir". 115 if not os.path.isabs(save_dir): 116 if not os.path.isabs(model_checkpoint_path): 117 model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir) 118 for i, p in enumerate(all_model_checkpoint_paths): 119 if not os.path.isabs(p): 120 all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir) 121 122 coord_checkpoint_proto = CheckpointState( 123 model_checkpoint_path=model_checkpoint_path, 124 all_model_checkpoint_paths=all_model_checkpoint_paths, 125 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 126 last_preserved_timestamp=last_preserved_timestamp) 127 128 return coord_checkpoint_proto 129 130 131@deprecation.deprecated( 132 date=None, 133 instructions=("Use `tf.train.CheckpointManager` to manage checkpoints " 134 "rather than manually editing the Checkpoint proto.")) 135@tf_export(v1=["train.update_checkpoint_state"]) 136def update_checkpoint_state(save_dir, 137 model_checkpoint_path, 138 all_model_checkpoint_paths=None, 139 latest_filename=None, 140 all_model_checkpoint_timestamps=None, 141 last_preserved_timestamp=None): 142 """Updates the content of the 'checkpoint' file. 143 144 This updates the checkpoint file containing a CheckpointState 145 proto. 146 147 Args: 148 save_dir: Directory where the model was saved. 149 model_checkpoint_path: The checkpoint file. 150 all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted 151 checkpoints, sorted from oldest to newest. If this is a non-empty list, 152 the last element must be equal to model_checkpoint_path. These paths 153 are also saved in the CheckpointState proto. 154 latest_filename: Optional name of the checkpoint file. Default to 155 'checkpoint'. 156 all_model_checkpoint_timestamps: Optional list of timestamps (floats, 157 seconds since the Epoch) indicating when the checkpoints in 158 `all_model_checkpoint_paths` were created. 159 last_preserved_timestamp: A float, indicating the number of seconds since 160 the Epoch when the last preserved checkpoint was written, e.g. due to a 161 `keep_checkpoint_every_n_hours` parameter (see 162 `tf.train.CheckpointManager` for an implementation). 163 Raises: 164 RuntimeError: If any of the model checkpoint paths conflict with the file 165 containing CheckpointSate. 166 """ 167 update_checkpoint_state_internal( 168 save_dir=save_dir, 169 model_checkpoint_path=model_checkpoint_path, 170 all_model_checkpoint_paths=all_model_checkpoint_paths, 171 latest_filename=latest_filename, 172 save_relative_paths=False, 173 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 174 last_preserved_timestamp=last_preserved_timestamp) 175 176 177def update_checkpoint_state_internal(save_dir, 178 model_checkpoint_path, 179 all_model_checkpoint_paths=None, 180 latest_filename=None, 181 save_relative_paths=False, 182 all_model_checkpoint_timestamps=None, 183 last_preserved_timestamp=None): 184 """Updates the content of the 'checkpoint' file. 185 186 This updates the checkpoint file containing a CheckpointState 187 proto. 188 189 Args: 190 save_dir: Directory where the model was saved. 191 model_checkpoint_path: The checkpoint file. 192 all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted 193 checkpoints, sorted from oldest to newest. If this is a non-empty list, 194 the last element must be equal to model_checkpoint_path. These paths 195 are also saved in the CheckpointState proto. 196 latest_filename: Optional name of the checkpoint file. Default to 197 'checkpoint'. 198 save_relative_paths: If `True`, will write relative paths to the checkpoint 199 state file. 200 all_model_checkpoint_timestamps: Optional list of timestamps (floats, 201 seconds since the Epoch) indicating when the checkpoints in 202 `all_model_checkpoint_paths` were created. 203 last_preserved_timestamp: A float, indicating the number of seconds since 204 the Epoch when the last preserved checkpoint was written, e.g. due to a 205 `keep_checkpoint_every_n_hours` parameter (see 206 `tf.train.CheckpointManager` for an implementation). 207 208 Raises: 209 RuntimeError: If any of the model checkpoint paths conflict with the file 210 containing CheckpointSate. 211 """ 212 # Writes the "checkpoint" file for the coordinator for later restoration. 213 coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename) 214 if save_relative_paths: 215 if os.path.isabs(model_checkpoint_path): 216 rel_model_checkpoint_path = os.path.relpath( 217 model_checkpoint_path, save_dir) 218 else: 219 rel_model_checkpoint_path = model_checkpoint_path 220 rel_all_model_checkpoint_paths = [] 221 for p in all_model_checkpoint_paths: 222 if os.path.isabs(p): 223 rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir)) 224 else: 225 rel_all_model_checkpoint_paths.append(p) 226 ckpt = generate_checkpoint_state_proto( 227 save_dir, 228 rel_model_checkpoint_path, 229 all_model_checkpoint_paths=rel_all_model_checkpoint_paths, 230 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 231 last_preserved_timestamp=last_preserved_timestamp) 232 else: 233 ckpt = generate_checkpoint_state_proto( 234 save_dir, 235 model_checkpoint_path, 236 all_model_checkpoint_paths=all_model_checkpoint_paths, 237 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 238 last_preserved_timestamp=last_preserved_timestamp) 239 240 if coord_checkpoint_filename == ckpt.model_checkpoint_path: 241 raise RuntimeError("Save path '%s' conflicts with path used for " 242 "checkpoint state. Please use a different save path." % 243 model_checkpoint_path) 244 245 # Preventing potential read/write race condition by *atomically* writing to a 246 # file. 247 file_io.atomic_write_string_to_file(coord_checkpoint_filename, 248 text_format.MessageToString(ckpt)) 249 250 251@tf_export("train.get_checkpoint_state") 252def get_checkpoint_state(checkpoint_dir, latest_filename=None): 253 """Returns CheckpointState proto from the "checkpoint" file. 254 255 If the "checkpoint" file contains a valid CheckpointState 256 proto, returns it. 257 258 Args: 259 checkpoint_dir: The directory of checkpoints. 260 latest_filename: Optional name of the checkpoint file. Default to 261 'checkpoint'. 262 263 Returns: 264 A CheckpointState if the state was available, None 265 otherwise. 266 267 Raises: 268 ValueError: if the checkpoint read doesn't have model_checkpoint_path set. 269 """ 270 ckpt = None 271 coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, 272 latest_filename) 273 f = None 274 try: 275 # Check that the file exists before opening it to avoid 276 # many lines of errors from colossus in the logs. 277 if file_io.file_exists(coord_checkpoint_filename): 278 file_content = file_io.read_file_to_string( 279 coord_checkpoint_filename) 280 ckpt = CheckpointState() 281 text_format.Merge(file_content, ckpt) 282 if not ckpt.model_checkpoint_path: 283 raise ValueError("Invalid checkpoint state loaded from " 284 + checkpoint_dir) 285 # For relative model_checkpoint_path and all_model_checkpoint_paths, 286 # prepend checkpoint_dir. 287 if not os.path.isabs(ckpt.model_checkpoint_path): 288 ckpt.model_checkpoint_path = os.path.join(checkpoint_dir, 289 ckpt.model_checkpoint_path) 290 for i, p in enumerate(ckpt.all_model_checkpoint_paths): 291 if not os.path.isabs(p): 292 ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p) 293 except errors.OpError as e: 294 # It's ok if the file cannot be read 295 logging.warning("%s: %s", type(e).__name__, e) 296 logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) 297 return None 298 except text_format.ParseError as e: 299 logging.warning("%s: %s", type(e).__name__, e) 300 logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) 301 return None 302 finally: 303 if f: 304 f.close() 305 return ckpt 306 307 308def _prefix_to_checkpoint_path(prefix, format_version): 309 """Returns the pathname of a checkpoint file, given the checkpoint prefix. 310 311 For V1 checkpoint, simply returns the prefix itself (the data file). For V2, 312 returns the pathname to the index file. 313 314 Args: 315 prefix: a string, the prefix of a checkpoint. 316 format_version: the checkpoint format version that corresponds to the 317 prefix. 318 Returns: 319 The pathname of a checkpoint file, taking into account the checkpoint 320 format version. 321 """ 322 if format_version == saver_pb2.SaverDef.V2: 323 return prefix + ".index" # The index file identifies a checkpoint. 324 return prefix # Just the data file. 325 326 327@tf_export("train.latest_checkpoint") 328def latest_checkpoint(checkpoint_dir, latest_filename=None): 329 """Finds the filename of latest saved checkpoint file. 330 331 Gets the checkpoint state given the provided checkpoint_dir and looks for a 332 corresponding TensorFlow 2 (preferred) or TensorFlow 1.x checkpoint path. 333 The latest_filename argument is only applicable if you are saving checkpoint 334 using `v1.train.Saver.save` 335 336 337 See the [Training Checkpoints 338 Guide](https://www.tensorflow.org/guide/checkpoint) for more details and 339 examples.` 340 341 Args: 342 checkpoint_dir: Directory where the variables were saved. 343 latest_filename: Optional name for the protocol buffer file that 344 contains the list of most recent checkpoint filenames. 345 See the corresponding argument to `v1.train.Saver.save`. 346 347 Returns: 348 The full path to the latest checkpoint or `None` if no checkpoint was found. 349 """ 350 # Pick the latest checkpoint based on checkpoint state. 351 ckpt = get_checkpoint_state(checkpoint_dir, latest_filename) 352 if ckpt and ckpt.model_checkpoint_path: 353 # Look for either a V2 path or a V1 path, with priority for V2. 354 v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, 355 saver_pb2.SaverDef.V2) 356 v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, 357 saver_pb2.SaverDef.V1) 358 if file_io.get_matching_files(v2_path) or file_io.get_matching_files( 359 v1_path): 360 return ckpt.model_checkpoint_path 361 else: 362 logging.error("Couldn't match files for checkpoint %s", 363 ckpt.model_checkpoint_path) 364 return None 365 366 367def checkpoint_exists_internal(checkpoint_prefix): 368 """Checks whether a V1 or V2 checkpoint exists with the specified prefix. 369 370 This is an internal function to check if a checkpoint exists, 371 since it takes into account the naming difference between V1 and V2 formats. 372 373 Args: 374 checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking 375 priority. Typically the result of `Saver.save()` or that of 376 `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or 377 V1/V2. 378 Returns: 379 A bool, true if a checkpoint referred to by `checkpoint_prefix` exists. 380 """ 381 pathname = _prefix_to_checkpoint_path(checkpoint_prefix, 382 saver_pb2.SaverDef.V2) 383 if file_io.get_matching_files(pathname): 384 return True 385 elif file_io.get_matching_files(checkpoint_prefix): 386 return True 387 else: 388 return False 389 390 391@deprecation.deprecated( 392 date=None, 393 instructions="Use standard file APIs to check for files with this prefix.") 394@tf_export(v1=["train.checkpoint_exists"]) 395def checkpoint_exists(checkpoint_prefix): 396 """Checks whether a V1 or V2 checkpoint exists with the specified prefix. 397 398 This is the recommended way to check if a checkpoint exists, since it takes 399 into account the naming difference between V1 and V2 formats. 400 401 Args: 402 checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking 403 priority. Typically the result of `Saver.save()` or that of 404 `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or 405 V1/V2. 406 407 Returns: 408 A bool, true if a checkpoint referred to by `checkpoint_prefix` exists. 409 """ 410 return checkpoint_exists_internal(checkpoint_prefix) 411 412 413@deprecation.deprecated( 414 date=None, 415 instructions="Use standard file utilities to get mtimes.") 416@tf_export(v1=["train.get_checkpoint_mtimes"]) 417def get_checkpoint_mtimes(checkpoint_prefixes): 418 """Returns the mtimes (modification timestamps) of the checkpoints. 419 420 Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files 421 exist, collect their mtime. Both V2 and V1 checkpoints are considered, in 422 that priority. 423 424 This is the recommended way to get the mtimes, since it takes into account 425 the naming difference between V1 and V2 formats. 426 427 Note: If not all checkpoints exist, the length of the returned mtimes list 428 will be smaller than the length of `checkpoint_prefixes` list, so mapping 429 checkpoints to corresponding mtimes will not be possible. 430 431 Args: 432 checkpoint_prefixes: a list of checkpoint paths, typically the results of 433 `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of 434 sharded/non-sharded or V1/V2. 435 Returns: 436 A list of mtimes (in microseconds) of the found checkpoints. 437 """ 438 mtimes = [] 439 440 def match_maybe_append(pathname): 441 fnames = file_io.get_matching_files(pathname) 442 if fnames: 443 mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9) 444 return True 445 return False 446 447 for checkpoint_prefix in checkpoint_prefixes: 448 # Tries V2's metadata file first. 449 pathname = _prefix_to_checkpoint_path(checkpoint_prefix, 450 saver_pb2.SaverDef.V2) 451 if match_maybe_append(pathname): 452 continue 453 # Otherwise, tries V1, where the prefix is the complete pathname. 454 match_maybe_append(checkpoint_prefix) 455 456 return mtimes 457 458 459@deprecation.deprecated( 460 date=None, 461 instructions="Use standard file APIs to delete files with this prefix.") 462@tf_export(v1=["train.remove_checkpoint"]) 463def remove_checkpoint(checkpoint_prefix, 464 checkpoint_format_version=saver_pb2.SaverDef.V2, 465 meta_graph_suffix="meta"): 466 """Removes a checkpoint given by `checkpoint_prefix`. 467 468 Args: 469 checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result 470 of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of 471 sharded/non-sharded or V1/V2. 472 checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to 473 `SaverDef.V2`. 474 meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. 475 """ 476 _delete_file_if_exists( 477 meta_graph_filename(checkpoint_prefix, meta_graph_suffix)) 478 if checkpoint_format_version == saver_pb2.SaverDef.V2: 479 # V2 has a metadata file and some data files. 480 _delete_file_if_exists(checkpoint_prefix + ".index") 481 _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????") 482 else: 483 # V1, Legacy. Exact match on the data file. 484 _delete_file_if_exists(checkpoint_prefix) 485 486 487def _delete_file_if_exists(filespec): 488 """Deletes files matching `filespec`.""" 489 for pathname in file_io.get_matching_files(filespec): 490 try: 491 file_io.delete_file(pathname) 492 except errors.NotFoundError: 493 logging.warning( 494 "Hit NotFoundError when deleting '%s', possibly because another " 495 "process/thread is also deleting/moving the same file", pathname) 496 497 498def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"): 499 """Returns the meta graph filename. 500 501 Args: 502 checkpoint_filename: Name of the checkpoint file. 503 meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. 504 505 Returns: 506 MetaGraph file name. 507 """ 508 # If the checkpoint_filename is sharded, the checkpoint_filename could 509 # be of format model.ckpt-step#-?????-of-shard#. For example, 510 # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002. 511 basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename) 512 suffixed_filename = ".".join([basename, meta_graph_suffix]) 513 return suffixed_filename 514 515 516# TODO(allenl): Allow tf.keras.Model instances in the constructor directly? 517@tf_export("train.CheckpointManager") 518class CheckpointManager(object): 519 """Manages multiple checkpoints by keeping some and deleting unneeded ones. 520 521 Example usage: 522 523 ```python 524 import tensorflow as tf 525 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 526 manager = tf.train.CheckpointManager( 527 checkpoint, directory="/tmp/model", max_to_keep=5) 528 status = checkpoint.restore(manager.latest_checkpoint) 529 while True: 530 # train 531 manager.save() 532 ``` 533 534 `CheckpointManager` preserves its own state across instantiations (see the 535 `__init__` documentation for details). Only one should be active in a 536 particular directory at a time. 537 """ 538 539 def __init__(self, 540 checkpoint, 541 directory, 542 max_to_keep, 543 keep_checkpoint_every_n_hours=None, 544 checkpoint_name="ckpt", 545 step_counter=None, 546 checkpoint_interval=None, 547 init_fn=None): 548 """Configure a `CheckpointManager` for use in `directory`. 549 550 If a `CheckpointManager` was previously used in `directory`, its 551 state will be restored. This includes the list of managed checkpoints and 552 the timestamp bookkeeping necessary to support 553 `keep_checkpoint_every_n_hours`. The behavior of the new `CheckpointManager` 554 will be the same as the previous `CheckpointManager`, including cleaning up 555 existing checkpoints if appropriate. 556 557 Checkpoints are only considered for deletion just after a new checkpoint has 558 been added. At that point, `max_to_keep` checkpoints will remain in an 559 "active set". Once a checkpoint is preserved by 560 `keep_checkpoint_every_n_hours` it will not be deleted by this 561 `CheckpointManager` or any future `CheckpointManager` instantiated in 562 `directory` (regardless of the new setting of 563 `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the 564 active set may be deleted by this `CheckpointManager` or a future 565 `CheckpointManager` instantiated in `directory` (subject to its 566 `max_to_keep` and `keep_checkpoint_every_n_hours` settings). 567 568 `CheckpointManager` can be also used for initializing the model if 569 there is no checkpoints for restoring in `directory`. An example usage is: 570 571 >>> import tempfile 572 573 >>> tmp_dir = tempfile.mkdtemp() 574 >>> checkpoint = tf.train.Checkpoint() 575 >>> init_path = checkpoint.save(os.path.join(tmp_dir, 'init')) 576 577 >>> def init_fn(): 578 ... # Partially restore the checkpoint from `init_path`. 579 ... checkpoint.restore(init_path) 580 581 >>> manager = tf.train.CheckpointManager( 582 ... checkpoint, 583 ... directory=os.path.join(tmp_dir, 'ckpt'), 584 ... max_to_keep=None, 585 ... init_fn=init_fn) 586 >>> # `restore_or_initialize` will call `init_fn` if there is no existing 587 >>> # checkpoint in `directory`. 588 >>> manager.restore_or_initialize() 589 590 Args: 591 checkpoint: The `tf.train.Checkpoint` instance to save and manage 592 checkpoints for. 593 directory: The path to a directory in which to write checkpoints. A 594 special file named "checkpoint" is also written to this directory (in a 595 human-readable text format) which contains the state of the 596 `CheckpointManager`. 597 max_to_keep: An integer, the number of checkpoints to keep. Unless 598 preserved by `keep_checkpoint_every_n_hours`, checkpoints will be 599 deleted from the active set, oldest first, until only `max_to_keep` 600 checkpoints remain. If `None`, no checkpoints are deleted and everything 601 stays in the active set. Note that `max_to_keep=None` will keep all 602 checkpoint paths in memory and in the checkpoint state protocol buffer 603 on disk. 604 keep_checkpoint_every_n_hours: Upon removal from the active set, a 605 checkpoint will be preserved if it has been at least 606 `keep_checkpoint_every_n_hours` since the last preserved checkpoint. The 607 default setting of `None` does not preserve any checkpoints in this way. 608 checkpoint_name: Custom name for the checkpoint file. 609 step_counter: A `tf.Variable` instance for checking the current step 610 counter value, in case users want to save checkpoints every N steps. 611 checkpoint_interval: An integer, indicates the minimum step interval 612 between two checkpoints. 613 init_fn: Callable. A function to do customized intialization if no 614 checkpoints are in the directory. 615 616 Raises: 617 ValueError: If `max_to_keep` is not a positive integer. 618 """ 619 self._checkpoint = checkpoint 620 self._save_counter_assign = None 621 if max_to_keep is not None and max_to_keep <= 0: 622 raise ValueError( 623 ("Expected a positive integer or `None` for `max_to_keep`, " 624 "got %d.") 625 % (max_to_keep,)) 626 self._max_to_keep = max_to_keep 627 self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours 628 self._directory = directory 629 self._checkpoint_prefix = os.path.join(directory, checkpoint_name) 630 self._init_fn = init_fn 631 632 if checkpoint_interval is not None: 633 if step_counter is None: 634 raise ValueError("`step_counter` should be passed if " 635 "`checkpoint_interval` is not None.") 636 self._last_checkpoint_step = None 637 self._step_counter = step_counter 638 self._checkpoint_interval = checkpoint_interval 639 640 recovered_state = get_checkpoint_state(directory) 641 current_clock = time.time() 642 self._maybe_delete = collections.OrderedDict() 643 if recovered_state is None: 644 self._latest_checkpoint = None 645 # Set the clock back slightly to avoid race conditions when quickly 646 # re-creating a CheckpointManager. 647 self._last_preserved_timestamp = current_clock - 1. 648 else: 649 self._latest_checkpoint = recovered_state.model_checkpoint_path 650 self._last_preserved_timestamp = recovered_state.last_preserved_timestamp 651 if current_clock < self._last_preserved_timestamp: 652 # Time seems to have reversed itself. In addition to this warning, we'll 653 # min() saved checkpoint timestamps with the current time to ensure that 654 # old checkpoints don't get deleted accidentally. 655 logging.warning( 656 ("time.time() returned a value %f seconds behind the last " 657 "preserved checkpoint timestamp.") 658 % (self._last_preserved_timestamp - current_clock,)) 659 self._last_preserved_timestamp = current_clock 660 all_timestamps = recovered_state.all_model_checkpoint_timestamps 661 all_paths = recovered_state.all_model_checkpoint_paths 662 del recovered_state # Uses modified values from now on 663 if not all_timestamps: 664 all_timestamps = [self._last_preserved_timestamp] * len(all_paths) 665 666 for filename, timestamp in zip(all_paths, all_timestamps): 667 timestamp = min(timestamp, current_clock) 668 if timestamp > self._last_preserved_timestamp: 669 self._maybe_delete[filename] = timestamp 670 671 @property 672 def directory(self): 673 return self._directory 674 675 @property 676 def checkpoint_interval(self): 677 return self._checkpoint_interval 678 679 @property 680 def latest_checkpoint(self): 681 """The prefix of the most recent checkpoint in `directory`. 682 683 Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is 684 the constructor argument to `CheckpointManager`. 685 686 Suitable for passing to `tf.train.Checkpoint.restore` to resume training. 687 688 Returns: 689 The checkpoint prefix. If there are no checkpoints, returns `None`. 690 """ 691 return self._latest_checkpoint 692 693 @property 694 def checkpoints(self): 695 """A list of managed checkpoints. 696 697 Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not 698 show up in this list (to avoid ever-growing filename lists). 699 700 Returns: 701 A list of filenames, sorted from oldest to newest. 702 """ 703 return list(self._maybe_delete.keys()) 704 705 def _sweep(self): 706 """Deletes or preserves managed checkpoints.""" 707 if not self._max_to_keep: 708 # Does not update self._last_preserved_timestamp, since everything is kept 709 # in the active set. 710 return 711 while len(self._maybe_delete) > self._max_to_keep: 712 filename, timestamp = self._maybe_delete.popitem(last=False) 713 # Even if we're keeping this checkpoint due to 714 # keep_checkpoint_every_n_hours, we won't reference it to avoid 715 # infinitely-growing CheckpointState protos. 716 if (self._keep_checkpoint_every_n_hours 717 and (timestamp - self._keep_checkpoint_every_n_hours * 3600. 718 >= self._last_preserved_timestamp)): 719 self._last_preserved_timestamp = timestamp 720 continue 721 _delete_file_if_exists(filename + ".index") 722 _delete_file_if_exists(filename + ".data-?????-of-?????") 723 724 def _record_state(self): 725 """Saves the `CheckpointManager`'s state in `directory`.""" 726 filenames, timestamps = zip(*self._maybe_delete.items()) 727 update_checkpoint_state_internal( 728 self._directory, 729 model_checkpoint_path=self.latest_checkpoint, 730 all_model_checkpoint_paths=filenames, 731 all_model_checkpoint_timestamps=timestamps, 732 last_preserved_timestamp=self._last_preserved_timestamp, 733 save_relative_paths=True) 734 735 @property 736 def _prefix(self): 737 """A common prefix for all checkpoints saved with this manager. 738 739 For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`, 740 `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be 741 numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each 742 checkpoint has several associated files 743 (e.g. `"/tmp/tf-model/ckpt-2.index"`). 744 745 Returns: 746 A string prefix. 747 """ 748 return self._checkpoint_prefix 749 750 @property 751 def checkpoint(self): 752 """Returns the `tf.train.Checkpoint` object.""" 753 return self._checkpoint 754 755 def save(self, checkpoint_number=None, check_interval=True, options=None): 756 """Creates a new checkpoint and manages it. 757 758 Args: 759 checkpoint_number: An optional integer, or an integer-dtype `Variable` or 760 `Tensor`, used to number the checkpoint. If `None` (default), 761 checkpoints are numbered using `checkpoint.save_counter`. Even if 762 `checkpoint_number` is provided, `save_counter` is still incremented. A 763 user-provided `checkpoint_number` is not incremented even if it is a 764 `Variable`. 765 check_interval: An optional boolean. The argument is only effective when 766 `checkpoint_interval` is passed into the manager. If `True`, the manager 767 will only save the checkpoint if the interval between checkpoints is 768 larger than `checkpoint_interval`. Otherwise it will always save the 769 checkpoint unless a checkpoint has already been saved for the current 770 step. 771 options: Optional `tf.train.CheckpointOptions` object. This argument only 772 works with TF2 checkpoint objects. For example, options = 773 tf.saved_model.SaveOptions(experimental_io_device='/job:localhost') 774 775 Returns: 776 The path to the new checkpoint. It is also recorded in the `checkpoints` 777 and `latest_checkpoint` properties. `None` if no checkpoint is saved. 778 """ 779 if self._checkpoint_interval is not None: 780 current_step = _evaluate(self._step_counter) 781 if self._last_checkpoint_step is not None: 782 if current_step == self._last_checkpoint_step: 783 return None 784 if check_interval and current_step < ( 785 self._last_checkpoint_step + self._checkpoint_interval): 786 return None 787 self._last_checkpoint_step = current_step 788 789 # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge 790 # slightly with a custom numbering option. 791 if context.executing_eagerly(): 792 save_counter = self._checkpoint.save_counter 793 save_counter.assign_add(1) 794 session = None 795 else: 796 session = ops.get_default_session() 797 798 def _initializing_creator(next_creator, **kwargs): 799 """Initialize the save counter if it has been newly created.""" 800 v = next_creator(**kwargs) 801 session.run(v.initializer) 802 return v 803 804 with variable_scope.variable_creator_scope(_initializing_creator): 805 save_counter = self._checkpoint.save_counter 806 if self._save_counter_assign is None: 807 self._save_counter_assign = save_counter.assign_add(1, read_value=False) 808 session.run(self._save_counter_assign) 809 if checkpoint_number is None: 810 checkpoint_number = save_counter 811 if not isinstance(checkpoint_number, compat.integral_types): 812 checkpoint_number = training_util.global_step( 813 sess=session, global_step_tensor=checkpoint_number) 814 prefix = "%s-%d" % (self._prefix, checkpoint_number) 815 if options is None: 816 save_path = self._checkpoint.write(prefix) 817 else: 818 save_path = self._checkpoint.write(prefix, options=options) 819 timestamp = time.time() 820 # If this is an overwritten checkpoint we were previously tracking, delete 821 # and reinsert it to make sure it goes to the end of the queue. 822 if save_path in self._maybe_delete: 823 del self._maybe_delete[save_path] 824 self._maybe_delete[save_path] = timestamp 825 self._latest_checkpoint = save_path 826 # Before deleting anything we update the Checkpoint proto with the new 827 # checkpoint. We'll go back and correct it after cleaning up old files, but 828 # a preemption while deleting will be more likely to see the new checkpoint 829 # this way. 830 self._record_state() 831 self._sweep() 832 # Write out the Checkpoint proto a second time, now without the deleted 833 # checkpoints. 834 self._record_state() 835 return save_path 836 837 def restore_or_initialize(self): 838 """Restore items in `checkpoint` from the latest checkpoint file. 839 840 This method will first try to restore from the most recent checkpoint in 841 `directory`. If no checkpoints exist in `directory`, and `init_fn` is 842 specified, this method will call `init_fn` to do customized 843 initialization. This can be used to support initialization from pretrained 844 models. 845 846 Note that unlike `tf.train.Checkpoint.restore()`, this method doesn't return 847 a load status object that users can run assertions on 848 (e.g. assert_consumed()). Thus to run assertions, users should directly use 849 `tf.train.Checkpoint.restore()` method. 850 851 Returns: 852 The restored checkpoint path if the lastest checkpoint is found and 853 restored. Otherwise None. 854 """ 855 if self._latest_checkpoint is not None: 856 self._checkpoint.restore(self._latest_checkpoint) 857 if self._checkpoint_interval is not None: 858 self._last_checkpoint_step = _evaluate(self._step_counter) 859 return self._latest_checkpoint 860 861 if self._init_fn is not None: 862 self._init_fn() 863 return None 864