1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16# pylint: disable=invalid-name 17"""Save and restore variables.""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23import os.path 24import re 25import time 26 27from google.protobuf import text_format 28 29from tensorflow.core.protobuf import saver_pb2 30from tensorflow.python.eager import context 31from tensorflow.python.framework import errors 32from tensorflow.python.framework import ops 33from tensorflow.python.lib.io import file_io 34from tensorflow.python.ops import variable_scope 35from tensorflow.python.platform import tf_logging as logging 36from tensorflow.python.training import training_util 37from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState 38from tensorflow.python.util import compat 39from tensorflow.python.util import deprecation 40from tensorflow.python.util.tf_export import tf_export 41 42 43def _GetCheckpointFilename(save_dir, latest_filename): 44 """Returns a filename for storing the CheckpointState. 45 46 Args: 47 save_dir: The directory for saving and restoring checkpoints. 48 latest_filename: Name of the file in 'save_dir' that is used 49 to store the CheckpointState. 50 51 Returns: 52 The path of the file that contains the CheckpointState proto. 53 """ 54 if latest_filename is None: 55 latest_filename = "checkpoint" 56 return os.path.join(save_dir, latest_filename) 57 58 59@tf_export(v1=["train.generate_checkpoint_state_proto"]) 60def generate_checkpoint_state_proto(save_dir, 61 model_checkpoint_path, 62 all_model_checkpoint_paths=None, 63 all_model_checkpoint_timestamps=None, 64 last_preserved_timestamp=None): 65 """Generates a checkpoint state proto. 66 67 Args: 68 save_dir: Directory where the model was saved. 69 model_checkpoint_path: The checkpoint file. 70 all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted 71 checkpoints, sorted from oldest to newest. If this is a non-empty list, 72 the last element must be equal to model_checkpoint_path. These paths 73 are also saved in the CheckpointState proto. 74 all_model_checkpoint_timestamps: A list of floats, indicating the number of 75 seconds since the Epoch when each checkpoint was generated. 76 last_preserved_timestamp: A float, indicating the number of seconds since 77 the Epoch when the last preserved checkpoint was written, e.g. due to a 78 `keep_checkpoint_every_n_hours` parameter (see 79 `tf.contrib.checkpoint.CheckpointManager` for an implementation). 80 Returns: 81 CheckpointState proto with model_checkpoint_path and 82 all_model_checkpoint_paths updated to either absolute paths or 83 relative paths to the current save_dir. 84 85 Raises: 86 ValueError: If `all_model_checkpoint_timestamps` was provided but its length 87 does not match `all_model_checkpoint_paths`. 88 """ 89 if all_model_checkpoint_paths is None: 90 all_model_checkpoint_paths = [] 91 92 if (not all_model_checkpoint_paths or 93 all_model_checkpoint_paths[-1] != model_checkpoint_path): 94 logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.", 95 model_checkpoint_path) 96 all_model_checkpoint_paths.append(model_checkpoint_path) 97 98 if (all_model_checkpoint_timestamps 99 and (len(all_model_checkpoint_timestamps) 100 != len(all_model_checkpoint_paths))): 101 raise ValueError( 102 ("Checkpoint timestamps, if provided, must match checkpoint paths (got " 103 "paths %s and timestamps %s)") 104 % (all_model_checkpoint_paths, all_model_checkpoint_timestamps)) 105 106 # Relative paths need to be rewritten to be relative to the "save_dir" 107 # if model_checkpoint_path already contains "save_dir". 108 if not os.path.isabs(save_dir): 109 if not os.path.isabs(model_checkpoint_path): 110 model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir) 111 for i in range(len(all_model_checkpoint_paths)): 112 p = all_model_checkpoint_paths[i] 113 if not os.path.isabs(p): 114 all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir) 115 116 coord_checkpoint_proto = CheckpointState( 117 model_checkpoint_path=model_checkpoint_path, 118 all_model_checkpoint_paths=all_model_checkpoint_paths, 119 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 120 last_preserved_timestamp=last_preserved_timestamp) 121 122 return coord_checkpoint_proto 123 124 125@deprecation.deprecated( 126 date=None, 127 instructions=("Use `tf.train.CheckpointManager` to manage checkpoints " 128 "rather than manually editing the Checkpoint proto.")) 129@tf_export(v1=["train.update_checkpoint_state"]) 130def update_checkpoint_state(save_dir, 131 model_checkpoint_path, 132 all_model_checkpoint_paths=None, 133 latest_filename=None, 134 all_model_checkpoint_timestamps=None, 135 last_preserved_timestamp=None): 136 """Updates the content of the 'checkpoint' file. 137 138 This updates the checkpoint file containing a CheckpointState 139 proto. 140 141 Args: 142 save_dir: Directory where the model was saved. 143 model_checkpoint_path: The checkpoint file. 144 all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted 145 checkpoints, sorted from oldest to newest. If this is a non-empty list, 146 the last element must be equal to model_checkpoint_path. These paths 147 are also saved in the CheckpointState proto. 148 latest_filename: Optional name of the checkpoint file. Default to 149 'checkpoint'. 150 all_model_checkpoint_timestamps: Optional list of timestamps (floats, 151 seconds since the Epoch) indicating when the checkpoints in 152 `all_model_checkpoint_paths` were created. 153 last_preserved_timestamp: A float, indicating the number of seconds since 154 the Epoch when the last preserved checkpoint was written, e.g. due to a 155 `keep_checkpoint_every_n_hours` parameter (see 156 `tf.contrib.checkpoint.CheckpointManager` for an implementation). 157 Raises: 158 RuntimeError: If any of the model checkpoint paths conflict with the file 159 containing CheckpointSate. 160 """ 161 update_checkpoint_state_internal( 162 save_dir=save_dir, 163 model_checkpoint_path=model_checkpoint_path, 164 all_model_checkpoint_paths=all_model_checkpoint_paths, 165 latest_filename=latest_filename, 166 save_relative_paths=False, 167 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 168 last_preserved_timestamp=last_preserved_timestamp) 169 170 171def update_checkpoint_state_internal(save_dir, 172 model_checkpoint_path, 173 all_model_checkpoint_paths=None, 174 latest_filename=None, 175 save_relative_paths=False, 176 all_model_checkpoint_timestamps=None, 177 last_preserved_timestamp=None): 178 """Updates the content of the 'checkpoint' file. 179 180 This updates the checkpoint file containing a CheckpointState 181 proto. 182 183 Args: 184 save_dir: Directory where the model was saved. 185 model_checkpoint_path: The checkpoint file. 186 all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted 187 checkpoints, sorted from oldest to newest. If this is a non-empty list, 188 the last element must be equal to model_checkpoint_path. These paths 189 are also saved in the CheckpointState proto. 190 latest_filename: Optional name of the checkpoint file. Default to 191 'checkpoint'. 192 save_relative_paths: If `True`, will write relative paths to the checkpoint 193 state file. 194 all_model_checkpoint_timestamps: Optional list of timestamps (floats, 195 seconds since the Epoch) indicating when the checkpoints in 196 `all_model_checkpoint_paths` were created. 197 last_preserved_timestamp: A float, indicating the number of seconds since 198 the Epoch when the last preserved checkpoint was written, e.g. due to a 199 `keep_checkpoint_every_n_hours` parameter (see 200 `tf.contrib.checkpoint.CheckpointManager` for an implementation). 201 202 Raises: 203 RuntimeError: If any of the model checkpoint paths conflict with the file 204 containing CheckpointSate. 205 """ 206 # Writes the "checkpoint" file for the coordinator for later restoration. 207 coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename) 208 if save_relative_paths: 209 if os.path.isabs(model_checkpoint_path): 210 rel_model_checkpoint_path = os.path.relpath( 211 model_checkpoint_path, save_dir) 212 else: 213 rel_model_checkpoint_path = model_checkpoint_path 214 rel_all_model_checkpoint_paths = [] 215 for p in all_model_checkpoint_paths: 216 if os.path.isabs(p): 217 rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir)) 218 else: 219 rel_all_model_checkpoint_paths.append(p) 220 ckpt = generate_checkpoint_state_proto( 221 save_dir, 222 rel_model_checkpoint_path, 223 all_model_checkpoint_paths=rel_all_model_checkpoint_paths, 224 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 225 last_preserved_timestamp=last_preserved_timestamp) 226 else: 227 ckpt = generate_checkpoint_state_proto( 228 save_dir, 229 model_checkpoint_path, 230 all_model_checkpoint_paths=all_model_checkpoint_paths, 231 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 232 last_preserved_timestamp=last_preserved_timestamp) 233 234 if coord_checkpoint_filename == ckpt.model_checkpoint_path: 235 raise RuntimeError("Save path '%s' conflicts with path used for " 236 "checkpoint state. Please use a different save path." % 237 model_checkpoint_path) 238 239 # Preventing potential read/write race condition by *atomically* writing to a 240 # file. 241 file_io.atomic_write_string_to_file(coord_checkpoint_filename, 242 text_format.MessageToString(ckpt)) 243 244 245@tf_export("train.get_checkpoint_state") 246def get_checkpoint_state(checkpoint_dir, latest_filename=None): 247 """Returns CheckpointState proto from the "checkpoint" file. 248 249 If the "checkpoint" file contains a valid CheckpointState 250 proto, returns it. 251 252 Args: 253 checkpoint_dir: The directory of checkpoints. 254 latest_filename: Optional name of the checkpoint file. Default to 255 'checkpoint'. 256 257 Returns: 258 A CheckpointState if the state was available, None 259 otherwise. 260 261 Raises: 262 ValueError: if the checkpoint read doesn't have model_checkpoint_path set. 263 """ 264 ckpt = None 265 coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, 266 latest_filename) 267 f = None 268 try: 269 # Check that the file exists before opening it to avoid 270 # many lines of errors from colossus in the logs. 271 if file_io.file_exists(coord_checkpoint_filename): 272 file_content = file_io.read_file_to_string( 273 coord_checkpoint_filename) 274 ckpt = CheckpointState() 275 text_format.Merge(file_content, ckpt) 276 if not ckpt.model_checkpoint_path: 277 raise ValueError("Invalid checkpoint state loaded from " 278 + checkpoint_dir) 279 # For relative model_checkpoint_path and all_model_checkpoint_paths, 280 # prepend checkpoint_dir. 281 if not os.path.isabs(ckpt.model_checkpoint_path): 282 ckpt.model_checkpoint_path = os.path.join(checkpoint_dir, 283 ckpt.model_checkpoint_path) 284 for i in range(len(ckpt.all_model_checkpoint_paths)): 285 p = ckpt.all_model_checkpoint_paths[i] 286 if not os.path.isabs(p): 287 ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p) 288 except errors.OpError as e: 289 # It's ok if the file cannot be read 290 logging.warning("%s: %s", type(e).__name__, e) 291 logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) 292 return None 293 except text_format.ParseError as e: 294 logging.warning("%s: %s", type(e).__name__, e) 295 logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) 296 return None 297 finally: 298 if f: 299 f.close() 300 return ckpt 301 302 303def _prefix_to_checkpoint_path(prefix, format_version): 304 """Returns the pathname of a checkpoint file, given the checkpoint prefix. 305 306 For V1 checkpoint, simply returns the prefix itself (the data file). For V2, 307 returns the pathname to the index file. 308 309 Args: 310 prefix: a string, the prefix of a checkpoint. 311 format_version: the checkpoint format version that corresponds to the 312 prefix. 313 Returns: 314 The pathname of a checkpoint file, taking into account the checkpoint 315 format version. 316 """ 317 if format_version == saver_pb2.SaverDef.V2: 318 return prefix + ".index" # The index file identifies a checkpoint. 319 return prefix # Just the data file. 320 321 322@tf_export("train.latest_checkpoint") 323def latest_checkpoint(checkpoint_dir, latest_filename=None): 324 """Finds the filename of latest saved checkpoint file. 325 326 Args: 327 checkpoint_dir: Directory where the variables were saved. 328 latest_filename: Optional name for the protocol buffer file that 329 contains the list of most recent checkpoint filenames. 330 See the corresponding argument to `Saver.save()`. 331 332 Returns: 333 The full path to the latest checkpoint or `None` if no checkpoint was found. 334 """ 335 # Pick the latest checkpoint based on checkpoint state. 336 ckpt = get_checkpoint_state(checkpoint_dir, latest_filename) 337 if ckpt and ckpt.model_checkpoint_path: 338 # Look for either a V2 path or a V1 path, with priority for V2. 339 v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, 340 saver_pb2.SaverDef.V2) 341 v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, 342 saver_pb2.SaverDef.V1) 343 if file_io.get_matching_files(v2_path) or file_io.get_matching_files( 344 v1_path): 345 return ckpt.model_checkpoint_path 346 else: 347 logging.error("Couldn't match files for checkpoint %s", 348 ckpt.model_checkpoint_path) 349 return None 350 351 352@deprecation.deprecated( 353 date=None, 354 instructions="Use standard file APIs to check for files with this prefix.") 355@tf_export(v1=["train.checkpoint_exists"]) 356def checkpoint_exists(checkpoint_prefix): 357 """Checks whether a V1 or V2 checkpoint exists with the specified prefix. 358 359 This is the recommended way to check if a checkpoint exists, since it takes 360 into account the naming difference between V1 and V2 formats. 361 362 Args: 363 checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking 364 priority. Typically the result of `Saver.save()` or that of 365 `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or 366 V1/V2. 367 Returns: 368 A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists. 369 """ 370 pathname = _prefix_to_checkpoint_path(checkpoint_prefix, 371 saver_pb2.SaverDef.V2) 372 if file_io.get_matching_files(pathname): 373 return True 374 elif file_io.get_matching_files(checkpoint_prefix): 375 return True 376 else: 377 return False 378 379 380@deprecation.deprecated( 381 date=None, 382 instructions="Use standard file utilities to get mtimes.") 383@tf_export(v1=["train.get_checkpoint_mtimes"]) 384def get_checkpoint_mtimes(checkpoint_prefixes): 385 """Returns the mtimes (modification timestamps) of the checkpoints. 386 387 Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files 388 exist, collect their mtime. Both V2 and V1 checkpoints are considered, in 389 that priority. 390 391 This is the recommended way to get the mtimes, since it takes into account 392 the naming difference between V1 and V2 formats. 393 394 Args: 395 checkpoint_prefixes: a list of checkpoint paths, typically the results of 396 `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of 397 sharded/non-sharded or V1/V2. 398 Returns: 399 A list of mtimes (in microseconds) of the found checkpoints. 400 """ 401 mtimes = [] 402 403 def match_maybe_append(pathname): 404 fnames = file_io.get_matching_files(pathname) 405 if fnames: 406 mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9) 407 return True 408 return False 409 410 for checkpoint_prefix in checkpoint_prefixes: 411 # Tries V2's metadata file first. 412 pathname = _prefix_to_checkpoint_path(checkpoint_prefix, 413 saver_pb2.SaverDef.V2) 414 if match_maybe_append(pathname): 415 continue 416 # Otherwise, tries V1, where the prefix is the complete pathname. 417 match_maybe_append(checkpoint_prefix) 418 419 return mtimes 420 421 422@deprecation.deprecated( 423 date=None, 424 instructions="Use standard file APIs to delete files with this prefix.") 425@tf_export(v1=["train.remove_checkpoint"]) 426def remove_checkpoint(checkpoint_prefix, 427 checkpoint_format_version=saver_pb2.SaverDef.V2, 428 meta_graph_suffix="meta"): 429 """Removes a checkpoint given by `checkpoint_prefix`. 430 431 Args: 432 checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result 433 of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of 434 sharded/non-sharded or V1/V2. 435 checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to 436 `SaverDef.V2`. 437 meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. 438 """ 439 _delete_file_if_exists( 440 meta_graph_filename(checkpoint_prefix, meta_graph_suffix)) 441 if checkpoint_format_version == saver_pb2.SaverDef.V2: 442 # V2 has a metadata file and some data files. 443 _delete_file_if_exists(checkpoint_prefix + ".index") 444 _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????") 445 else: 446 # V1, Legacy. Exact match on the data file. 447 _delete_file_if_exists(checkpoint_prefix) 448 449 450def _delete_file_if_exists(filespec): 451 """Deletes files matching `filespec`.""" 452 for pathname in file_io.get_matching_files(filespec): 453 file_io.delete_file(pathname) 454 455 456def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"): 457 """Returns the meta graph filename. 458 459 Args: 460 checkpoint_filename: Name of the checkpoint file. 461 meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. 462 463 Returns: 464 MetaGraph file name. 465 """ 466 # If the checkpoint_filename is sharded, the checkpoint_filename could 467 # be of format model.ckpt-step#-?????-of-shard#. For example, 468 # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002. 469 basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename) 470 suffixed_filename = ".".join([basename, meta_graph_suffix]) 471 return suffixed_filename 472 473 474# TODO(allenl): Allow tf.keras.Model instances in the constructor directly? 475@tf_export("train.CheckpointManager") 476class CheckpointManager(object): 477 """Deletes old checkpoints. 478 479 Example usage: 480 ```python 481 import tensorflow as tf 482 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 483 manager = tf.contrib.checkpoint.CheckpointManager( 484 checkpoint, directory="/tmp/model", max_to_keep=5) 485 status = checkpoint.restore(manager.latest_checkpoint) 486 while True: 487 # train 488 manager.save() 489 ``` 490 491 `CheckpointManager` preserves its own state across instantiations (see the 492 `__init__` documentation for details). Only one should be active in a 493 particular directory at a time. 494 """ 495 496 def __init__(self, checkpoint, directory, 497 max_to_keep, keep_checkpoint_every_n_hours=None): 498 """Configure a `CheckpointManager` for use in `directory`. 499 500 If a `CheckpointManager` was previously used in `directory`, its 501 state will be restored. This includes the list of managed checkpoints and 502 the timestamp bookkeeping necessary to support 503 `keep_checkpoint_every_n_hours`. The behavior of the new `CheckpointManager` 504 will be the same as the previous `CheckpointManager`, including cleaning up 505 existing checkpoints if appropriate. 506 507 Checkpoints are only considered for deletion just after a new checkpoint has 508 been added. At that point, `max_to_keep` checkpoints will remain in an 509 "active set". Once a checkpoint is preserved by 510 `keep_checkpoint_every_n_hours` it will not be deleted by this 511 `CheckpointManager` or any future `CheckpointManager` instantiated in 512 `directory` (regardless of the new setting of 513 `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the 514 active set may be deleted by this `CheckpointManager` or a future 515 `CheckpointManager` instantiated in `directory` (subject to its 516 `max_to_keep` and `keep_checkpoint_every_n_hours` settings). 517 518 Args: 519 checkpoint: The `tf.train.Checkpoint` instance to save and manage 520 checkpoints for. 521 directory: The path to a directory in which to write checkpoints. A 522 special file named "checkpoint" is also written to this directory (in a 523 human-readable text format) which contains the state of the 524 `CheckpointManager`. 525 max_to_keep: An integer, the number of checkpoints to keep. Unless 526 preserved by `keep_checkpoint_every_n_hours`, checkpoints will be 527 deleted from the active set, oldest first, until only `max_to_keep` 528 checkpoints remain. If `None`, no checkpoints are deleted and everything 529 stays in the active set. Note that `max_to_keep=None` will keep all 530 checkpoint paths in memory and in the checkpoint state protocol buffer 531 on disk. 532 keep_checkpoint_every_n_hours: Upon removal from the active set, a 533 checkpoint will be preserved if it has been at least 534 `keep_checkpoint_every_n_hours` since the last preserved checkpoint. The 535 default setting of `None` does not preserve any checkpoints in this way. 536 537 Raises: 538 ValueError: If `max_to_keep` is not a positive integer. 539 """ 540 self._checkpoint = checkpoint 541 self._save_counter_assign = None 542 if max_to_keep is not None and max_to_keep <= 0: 543 raise ValueError( 544 ("Expected a positive integer or `None` for `max_to_max_to_keep`, " 545 "got %d.") 546 % (max_to_keep,)) 547 self._max_to_keep = max_to_keep 548 self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours 549 self._directory = directory 550 self._checkpoint_prefix = os.path.join(directory, "ckpt") 551 recovered_state = get_checkpoint_state(directory) 552 current_clock = time.time() 553 self._maybe_delete = collections.OrderedDict() 554 if recovered_state is None: 555 self._latest_checkpoint = None 556 # Set the clock back slightly to avoid race conditions when quckly 557 # re-creating a CheckpointManager. 558 self._last_preserved_timestamp = current_clock - 1. 559 else: 560 self._latest_checkpoint = recovered_state.model_checkpoint_path 561 self._last_preserved_timestamp = recovered_state.last_preserved_timestamp 562 if current_clock < self._last_preserved_timestamp: 563 # Time seems to have reversed itself. In addition to this warning, we'll 564 # min() saved checkpoint timestamps with the current time to ensure that 565 # old checkpoints don't get deleted accidentally. 566 logging.warning( 567 ("time.time() returned a value %f seconds behind the last " 568 "preserved checkpoint timestamp.") 569 % (self._last_preserved_timestamp - current_clock,)) 570 self._last_preserved_timestamp = current_clock 571 all_timestamps = recovered_state.all_model_checkpoint_timestamps 572 all_paths = recovered_state.all_model_checkpoint_paths 573 del recovered_state # Uses modified values from now on 574 if not all_timestamps: 575 all_timestamps = [self._last_preserved_timestamp] * len(all_paths) 576 577 for filename, timestamp in zip(all_paths, all_timestamps): 578 timestamp = min(timestamp, current_clock) 579 if timestamp > self._last_preserved_timestamp: 580 self._maybe_delete[filename] = timestamp 581 582 @property 583 def latest_checkpoint(self): 584 """The prefix of the most recent checkpoint in `directory`. 585 586 Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is 587 the constructor argument to `CheckpointManager`. 588 589 Suitable for passing to `tf.train.Checkpoint.restore` to resume training. 590 591 Returns: 592 The checkpoint prefix. If there are no checkpoints, returns `None`. 593 """ 594 return self._latest_checkpoint 595 596 @property 597 def checkpoints(self): 598 """A list of managed checkpoints. 599 600 Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not 601 show up in this list (to avoid ever-growing filename lists). 602 603 Returns: 604 A list of filenames, sorted from oldest to newest. 605 """ 606 return list(self._maybe_delete.keys()) 607 608 def _sweep(self): 609 """Deletes or preserves managed checkpoints.""" 610 if not self._max_to_keep: 611 # Does not update self._last_preserved_timestamp, since everything is kept 612 # in the active set. 613 return 614 while len(self._maybe_delete) > self._max_to_keep: 615 filename, timestamp = self._maybe_delete.popitem(last=False) 616 # Even if we're keeping this checkpoint due to 617 # keep_checkpoint_every_n_hours, we won't reference it to avoid 618 # infinitely-growing CheckpointState protos. 619 if (self._keep_checkpoint_every_n_hours 620 and (timestamp - self._keep_checkpoint_every_n_hours * 3600. 621 >= self._last_preserved_timestamp)): 622 self._last_preserved_timestamp = timestamp 623 continue 624 _delete_file_if_exists(filename + ".index") 625 _delete_file_if_exists(filename + ".data-?????-of-?????") 626 627 def _record_state(self): 628 """Saves the `CheckpointManager`'s state in `directory`.""" 629 filenames, timestamps = zip(*self._maybe_delete.items()) 630 update_checkpoint_state_internal( 631 self._directory, 632 model_checkpoint_path=self.latest_checkpoint, 633 all_model_checkpoint_paths=filenames, 634 all_model_checkpoint_timestamps=timestamps, 635 last_preserved_timestamp=self._last_preserved_timestamp, 636 save_relative_paths=True) 637 638 @property 639 def _prefix(self): 640 """A common prefix for all checkpoints saved with this manager. 641 642 For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`, 643 `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be 644 numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each 645 checkpoint has several associated files 646 (e.g. `"/tmp/tf-model/ckpt-2.index"`). 647 648 Returns: 649 A string prefix. 650 """ 651 return self._checkpoint_prefix 652 653 def save(self, checkpoint_number=None): 654 """Creates a new checkpoint and manages it. 655 656 Args: 657 checkpoint_number: An optional integer, or an integer-dtype `Variable` or 658 `Tensor`, used to number the checkpoint. If `None` (default), 659 checkpoints are numbered using `checkpoint.save_counter`. Even if 660 `checkpoint_number` is provided, `save_counter` is still incremented. A 661 user-provided `checkpoint_number` is not incremented even if it is a 662 `Variable`. 663 664 Returns: 665 The path to the new checkpoint. It is also recorded in the `checkpoints` 666 and `latest_checkpoint` properies. 667 """ 668 # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge 669 # slightly with a custom numbering option. 670 if context.executing_eagerly(): 671 save_counter = self._checkpoint.save_counter 672 save_counter.assign_add(1) 673 session = None 674 else: 675 session = ops.get_default_session() 676 677 def _initializing_creator(next_creator, **kwargs): 678 """Initialize the save counter if it has been newly created.""" 679 v = next_creator(**kwargs) 680 session.run(v.initializer) 681 return v 682 683 with variable_scope.variable_creator_scope(_initializing_creator): 684 save_counter = self._checkpoint.save_counter 685 if self._save_counter_assign is None: 686 self._save_counter_assign = save_counter.assign_add(1, read_value=False) 687 session.run(self._save_counter_assign) 688 if checkpoint_number is None: 689 checkpoint_number = save_counter 690 if not isinstance(checkpoint_number, compat.integral_types): 691 checkpoint_number = training_util.global_step( 692 sess=session, global_step_tensor=checkpoint_number) 693 prefix = "%s-%d" % (self._prefix, checkpoint_number) 694 save_path = self._checkpoint.write(prefix) 695 timestamp = time.time() 696 # If this is an overwritten checkpoint we were previously tracking, delete 697 # and reinsert it to make sure it goes to the end of the queue. 698 if save_path in self._maybe_delete: 699 del self._maybe_delete[save_path] 700 self._maybe_delete[save_path] = timestamp 701 self._latest_checkpoint = save_path 702 self._sweep() 703 self._record_state() 704 return save_path 705