# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=invalid-name
"""Save and restore variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os.path
import re
import time

from google.protobuf import text_format

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _GetCheckpointFilename(save_dir, latest_filename):
  """Returns a filename for storing the CheckpointState.

  Args:
    save_dir: The directory for saving and restoring checkpoints.
    latest_filename: Name of the file in 'save_dir' that is used
      to store the CheckpointState.

  Returns:
    The path of the file that contains the CheckpointState proto.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  return os.path.join(save_dir, latest_filename)


@tf_export(v1=["train.generate_checkpoint_state_proto"])
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None,
                                    all_model_checkpoint_timestamps=None,
                                    last_preserved_timestamp=None):
  """Generates a checkpoint state proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest. If this is a non-empty list,
      the last element must be equal to model_checkpoint_path. These paths
      are also saved in the CheckpointState proto.
    all_model_checkpoint_timestamps: A list of floats, indicating the number of
      seconds since the Epoch when each checkpoint was generated.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).

  Returns:
    CheckpointState proto with model_checkpoint_path and
    all_model_checkpoint_paths updated to either absolute paths or
    relative paths to the current save_dir.

  Raises:
    ValueError: If `all_model_checkpoint_timestamps` was provided but its
      length does not match `all_model_checkpoint_paths`.
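
  For illustration only (the paths below are hypothetical), the relative-path
  rewriting described above means that a call such as

  ```python
  ckpt = generate_checkpoint_state_proto(
      "train_dir", "train_dir/model.ckpt-200",
      all_model_checkpoint_paths=["train_dir/model.ckpt-100",
                                  "train_dir/model.ckpt-200"])
  ```

  yields `ckpt.model_checkpoint_path == "model.ckpt-200"`, i.e. a path
  relative to `save_dir`.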
88 """ 89 if all_model_checkpoint_paths is None: 90 all_model_checkpoint_paths = [] 91 92 if (not all_model_checkpoint_paths or 93 all_model_checkpoint_paths[-1] != model_checkpoint_path): 94 logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.", 95 model_checkpoint_path) 96 all_model_checkpoint_paths.append(model_checkpoint_path) 97 98 if (all_model_checkpoint_timestamps 99 and (len(all_model_checkpoint_timestamps) 100 != len(all_model_checkpoint_paths))): 101 raise ValueError( 102 ("Checkpoint timestamps, if provided, must match checkpoint paths (got " 103 "paths %s and timestamps %s)") 104 % (all_model_checkpoint_paths, all_model_checkpoint_timestamps)) 105 106 # Relative paths need to be rewritten to be relative to the "save_dir" 107 # if model_checkpoint_path already contains "save_dir". 108 if not os.path.isabs(save_dir): 109 if not os.path.isabs(model_checkpoint_path): 110 model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir) 111 for i, p in enumerate(all_model_checkpoint_paths): 112 if not os.path.isabs(p): 113 all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir) 114 115 coord_checkpoint_proto = CheckpointState( 116 model_checkpoint_path=model_checkpoint_path, 117 all_model_checkpoint_paths=all_model_checkpoint_paths, 118 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 119 last_preserved_timestamp=last_preserved_timestamp) 120 121 return coord_checkpoint_proto 122 123 124@deprecation.deprecated( 125 date=None, 126 instructions=("Use `tf.train.CheckpointManager` to manage checkpoints " 127 "rather than manually editing the Checkpoint proto.")) 128@tf_export(v1=["train.update_checkpoint_state"]) 129def update_checkpoint_state(save_dir, 130 model_checkpoint_path, 131 all_model_checkpoint_paths=None, 132 latest_filename=None, 133 all_model_checkpoint_timestamps=None, 134 last_preserved_timestamp=None): 135 """Updates the content of the 'checkpoint' file. 136 137 This updates the checkpoint file containing a CheckpointState 138 proto. 139 140 Args: 141 save_dir: Directory where the model was saved. 142 model_checkpoint_path: The checkpoint file. 143 all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted 144 checkpoints, sorted from oldest to newest. If this is a non-empty list, 145 the last element must be equal to model_checkpoint_path. These paths 146 are also saved in the CheckpointState proto. 147 latest_filename: Optional name of the checkpoint file. Default to 148 'checkpoint'. 149 all_model_checkpoint_timestamps: Optional list of timestamps (floats, 150 seconds since the Epoch) indicating when the checkpoints in 151 `all_model_checkpoint_paths` were created. 152 last_preserved_timestamp: A float, indicating the number of seconds since 153 the Epoch when the last preserved checkpoint was written, e.g. due to a 154 `keep_checkpoint_every_n_hours` parameter (see 155 `tf.train.CheckpointManager` for an implementation). 156 Raises: 157 RuntimeError: If any of the model checkpoint paths conflict with the file 158 containing CheckpointSate. 
159 """ 160 update_checkpoint_state_internal( 161 save_dir=save_dir, 162 model_checkpoint_path=model_checkpoint_path, 163 all_model_checkpoint_paths=all_model_checkpoint_paths, 164 latest_filename=latest_filename, 165 save_relative_paths=False, 166 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 167 last_preserved_timestamp=last_preserved_timestamp) 168 169 170def update_checkpoint_state_internal(save_dir, 171 model_checkpoint_path, 172 all_model_checkpoint_paths=None, 173 latest_filename=None, 174 save_relative_paths=False, 175 all_model_checkpoint_timestamps=None, 176 last_preserved_timestamp=None): 177 """Updates the content of the 'checkpoint' file. 178 179 This updates the checkpoint file containing a CheckpointState 180 proto. 181 182 Args: 183 save_dir: Directory where the model was saved. 184 model_checkpoint_path: The checkpoint file. 185 all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted 186 checkpoints, sorted from oldest to newest. If this is a non-empty list, 187 the last element must be equal to model_checkpoint_path. These paths 188 are also saved in the CheckpointState proto. 189 latest_filename: Optional name of the checkpoint file. Default to 190 'checkpoint'. 191 save_relative_paths: If `True`, will write relative paths to the checkpoint 192 state file. 193 all_model_checkpoint_timestamps: Optional list of timestamps (floats, 194 seconds since the Epoch) indicating when the checkpoints in 195 `all_model_checkpoint_paths` were created. 196 last_preserved_timestamp: A float, indicating the number of seconds since 197 the Epoch when the last preserved checkpoint was written, e.g. due to a 198 `keep_checkpoint_every_n_hours` parameter (see 199 `tf.train.CheckpointManager` for an implementation). 200 201 Raises: 202 RuntimeError: If any of the model checkpoint paths conflict with the file 203 containing CheckpointSate. 204 """ 205 # Writes the "checkpoint" file for the coordinator for later restoration. 206 coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename) 207 if save_relative_paths: 208 if os.path.isabs(model_checkpoint_path): 209 rel_model_checkpoint_path = os.path.relpath( 210 model_checkpoint_path, save_dir) 211 else: 212 rel_model_checkpoint_path = model_checkpoint_path 213 rel_all_model_checkpoint_paths = [] 214 for p in all_model_checkpoint_paths: 215 if os.path.isabs(p): 216 rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir)) 217 else: 218 rel_all_model_checkpoint_paths.append(p) 219 ckpt = generate_checkpoint_state_proto( 220 save_dir, 221 rel_model_checkpoint_path, 222 all_model_checkpoint_paths=rel_all_model_checkpoint_paths, 223 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 224 last_preserved_timestamp=last_preserved_timestamp) 225 else: 226 ckpt = generate_checkpoint_state_proto( 227 save_dir, 228 model_checkpoint_path, 229 all_model_checkpoint_paths=all_model_checkpoint_paths, 230 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 231 last_preserved_timestamp=last_preserved_timestamp) 232 233 if coord_checkpoint_filename == ckpt.model_checkpoint_path: 234 raise RuntimeError("Save path '%s' conflicts with path used for " 235 "checkpoint state. Please use a different save path." % 236 model_checkpoint_path) 237 238 # Preventing potential read/write race condition by *atomically* writing to a 239 # file. 
  file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                      text_format.MessageToString(ckpt))


@tf_export("train.get_checkpoint_state")
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file. Defaults to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  f = None
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if file_io.file_exists(coord_checkpoint_filename):
      file_content = file_io.read_file_to_string(
          coord_checkpoint_filename)
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      if not ckpt.model_checkpoint_path:
        raise ValueError("Invalid checkpoint state loaded from "
                         + checkpoint_dir)
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(ckpt.model_checkpoint_path):
        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                  ckpt.model_checkpoint_path)
      for i, p in enumerate(ckpt.all_model_checkpoint_paths):
        if not os.path.isabs(p):
          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except errors.OpError as e:
    # It's ok if the file cannot be read
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  except text_format.ParseError as e:
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  finally:
    if f:
      f.close()
  return ckpt


def _prefix_to_checkpoint_path(prefix, format_version):
  """Returns the pathname of a checkpoint file, given the checkpoint prefix.

  For a V1 checkpoint, simply returns the prefix itself (the data file). For
  V2, returns the pathname to the index file.

  Args:
    prefix: a string, the prefix of a checkpoint.
    format_version: the checkpoint format version that corresponds to the
      prefix.

  Returns:
    The pathname of a checkpoint file, taking into account the checkpoint
    format version.
  """
  if format_version == saver_pb2.SaverDef.V2:
    return prefix + ".index"  # The index file identifies a checkpoint.
  return prefix  # Just the data file.


@tf_export("train.latest_checkpoint")
def latest_checkpoint(checkpoint_dir, latest_filename=None):
  """Finds the filename of the latest saved checkpoint file.

  Gets the checkpoint state given the provided checkpoint_dir and looks for a
  corresponding TensorFlow 2 (preferred) or TensorFlow 1.x checkpoint path.
  The `latest_filename` argument is only applicable if you are saving
  checkpoints using `v1.Saver.save`.

  See the [Training Checkpoints
  Guide](https://www.tensorflow.org/guide/checkpoint) for more details and
  examples.

  Args:
    checkpoint_dir: Directory where the variables were saved.
    latest_filename: Optional name for the protocol buffer file that
      contains the list of most recent checkpoint filenames.
      See the corresponding argument to `v1.Saver.save`.

  Returns:
    The full path to the latest checkpoint or `None` if no checkpoint was
    found.
  """
  # Pick the latest checkpoint based on checkpoint state.
  ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
  if ckpt and ckpt.model_checkpoint_path:
    # Look for either a V2 path or a V1 path, with priority for V2.
    v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V2)
    v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V1)
    if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
        v1_path):
      return ckpt.model_checkpoint_path
    else:
      logging.error("Couldn't match files for checkpoint %s",
                    ckpt.model_checkpoint_path)
  return None


def checkpoint_exists_internal(checkpoint_prefix):
  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.

  This is an internal function to check if a checkpoint exists,
  since it takes into account the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.

  Returns:
    A bool, true if a checkpoint referred to by `checkpoint_prefix` exists.
  """
  pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                        saver_pb2.SaverDef.V2)
  if file_io.get_matching_files(pathname):
    return True
  elif file_io.get_matching_files(checkpoint_prefix):
    return True
  else:
    return False


@deprecation.deprecated(
    date=None,
    instructions="Use standard file APIs to check for files with this prefix.")
@tf_export(v1=["train.checkpoint_exists"])
def checkpoint_exists(checkpoint_prefix):
  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.

  This is the recommended way to check if a checkpoint exists, since it takes
  into account the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.

  Returns:
    A bool, true if a checkpoint referred to by `checkpoint_prefix` exists.
  """
  return checkpoint_exists_internal(checkpoint_prefix)


@deprecation.deprecated(
    date=None,
    instructions="Use standard file utilities to get mtimes.")
@tf_export(v1=["train.get_checkpoint_mtimes"])
def get_checkpoint_mtimes(checkpoint_prefixes):
  """Returns the mtimes (modification timestamps) of the checkpoints.

  Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files
  exist, collects their mtimes. Both V2 and V1 checkpoints are considered, in
  that priority.

  This is the recommended way to get the mtimes, since it takes into account
  the naming difference between V1 and V2 formats.

  Note: If not all checkpoints exist, the length of the returned mtimes list
  will be smaller than the length of `checkpoint_prefixes` list, so mapping
  checkpoints to corresponding mtimes will not be possible.
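
  For illustration only (the prefixes and the timestamp are hypothetical): if
  only the second checkpoint still exists on disk, a single mtime comes back:

  ```python
  get_checkpoint_mtimes(["/tmp/model.ckpt-100", "/tmp/model.ckpt-200"])
  # ==> [1552424917.2]
  ```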

  Args:
    checkpoint_prefixes: a list of checkpoint paths, typically the results of
      `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
      sharded/non-sharded or V1/V2.

  Returns:
    A list of mtimes (in seconds since the Epoch) of the found checkpoints.
  """
  mtimes = []

  def match_maybe_append(pathname):
    fnames = file_io.get_matching_files(pathname)
    if fnames:
      mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
      return True
    return False

  for checkpoint_prefix in checkpoint_prefixes:
    # Tries V2's metadata file first.
    pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                          saver_pb2.SaverDef.V2)
    if match_maybe_append(pathname):
      continue
    # Otherwise, tries V1, where the prefix is the complete pathname.
    match_maybe_append(checkpoint_prefix)

  return mtimes


@deprecation.deprecated(
    date=None,
    instructions="Use standard file APIs to delete files with this prefix.")
@tf_export(v1=["train.remove_checkpoint"])
def remove_checkpoint(checkpoint_prefix,
                      checkpoint_format_version=saver_pb2.SaverDef.V2,
                      meta_graph_suffix="meta"):
  """Removes a checkpoint given by `checkpoint_prefix`.

  Args:
    checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the
      result of `Saver.save()` or that of `tf.train.latest_checkpoint()`,
      regardless of sharded/non-sharded or V1/V2.
    checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
      `SaverDef.V2`.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
  """
  _delete_file_if_exists(
      meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
  if checkpoint_format_version == saver_pb2.SaverDef.V2:
    # V2 has a metadata file and some data files.
    _delete_file_if_exists(checkpoint_prefix + ".index")
    _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
  else:
    # V1, Legacy. Exact match on the data file.
    _delete_file_if_exists(checkpoint_prefix)


def _delete_file_if_exists(filespec):
  """Deletes files matching `filespec`."""
  for pathname in file_io.get_matching_files(filespec):
    file_io.delete_file(pathname)


def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
  """Returns the meta graph filename.

  Args:
    checkpoint_filename: Name of the checkpoint file.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.

  Returns:
    MetaGraph file name.
  """
  # If the checkpoint_filename is sharded, the checkpoint_filename could
  # be of format model.ckpt-step#-?????-of-shard#. For example,
  # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
  basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
  suffixed_filename = ".".join([basename, meta_graph_suffix])
  return suffixed_filename


# TODO(allenl): Allow tf.keras.Model instances in the constructor directly?
@tf_export("train.CheckpointManager")
class CheckpointManager(object):
  """Deletes old checkpoints.

  Example usage:

  ```python
  import tensorflow as tf
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(
      checkpoint, directory="/tmp/model", max_to_keep=5)
  status = checkpoint.restore(manager.latest_checkpoint)
  while True:
    # train
    manager.save()
  ```

  `CheckpointManager` preserves its own state across instantiations (see the
  `__init__` documentation for details). Only one should be active in a
  particular directory at a time.
  """

  def __init__(self,
               checkpoint,
               directory,
               max_to_keep,
               keep_checkpoint_every_n_hours=None,
               checkpoint_name="ckpt"):
    """Configure a `CheckpointManager` for use in `directory`.

    If a `CheckpointManager` was previously used in `directory`, its
    state will be restored. This includes the list of managed checkpoints and
    the timestamp bookkeeping necessary to support
    `keep_checkpoint_every_n_hours`. The behavior of the new
    `CheckpointManager` will be the same as the previous `CheckpointManager`,
    including cleaning up existing checkpoints if appropriate.

    Checkpoints are only considered for deletion just after a new checkpoint
    has been added. At that point, `max_to_keep` checkpoints will remain in an
    "active set". Once a checkpoint is preserved by
    `keep_checkpoint_every_n_hours` it will not be deleted by this
    `CheckpointManager` or any future `CheckpointManager` instantiated in
    `directory` (regardless of the new setting of
    `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the
    active set may be deleted by this `CheckpointManager` or a future
    `CheckpointManager` instantiated in `directory` (subject to its
    `max_to_keep` and `keep_checkpoint_every_n_hours` settings).

    Args:
      checkpoint: The `tf.train.Checkpoint` instance to save and manage
        checkpoints for.
      directory: The path to a directory in which to write checkpoints. A
        special file named "checkpoint" is also written to this directory (in a
        human-readable text format) which contains the state of the
        `CheckpointManager`.
      max_to_keep: An integer, the number of checkpoints to keep. Unless
        preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
        deleted from the active set, oldest first, until only `max_to_keep`
        checkpoints remain. If `None`, no checkpoints are deleted and everything
        stays in the active set. Note that `max_to_keep=None` will keep all
        checkpoint paths in memory and in the checkpoint state protocol buffer
        on disk.
      keep_checkpoint_every_n_hours: Upon removal from the active set, a
        checkpoint will be preserved if at least
        `keep_checkpoint_every_n_hours` hours have passed since the last
        preserved checkpoint. The default setting of `None` does not preserve
        any checkpoints in this way.
      checkpoint_name: Custom name for the checkpoint file.

    Raises:
      ValueError: If `max_to_keep` is not a positive integer.
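
    For illustration only (the directory and settings below are hypothetical),
    a manager that keeps the three most recent checkpoints and additionally
    preserves one checkpoint roughly every six hours could be constructed as:

    ```python
    manager = tf.train.CheckpointManager(
        checkpoint, directory="/tmp/model", max_to_keep=3,
        keep_checkpoint_every_n_hours=6)
    ```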
575 """ 576 self._checkpoint = checkpoint 577 self._save_counter_assign = None 578 if max_to_keep is not None and max_to_keep <= 0: 579 raise ValueError( 580 ("Expected a positive integer or `None` for `max_to_keep`, " 581 "got %d.") 582 % (max_to_keep,)) 583 self._max_to_keep = max_to_keep 584 self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours 585 self._directory = directory 586 self._checkpoint_prefix = os.path.join(directory, checkpoint_name) 587 recovered_state = get_checkpoint_state(directory) 588 current_clock = time.time() 589 self._maybe_delete = collections.OrderedDict() 590 if recovered_state is None: 591 self._latest_checkpoint = None 592 # Set the clock back slightly to avoid race conditions when quckly 593 # re-creating a CheckpointManager. 594 self._last_preserved_timestamp = current_clock - 1. 595 else: 596 self._latest_checkpoint = recovered_state.model_checkpoint_path 597 self._last_preserved_timestamp = recovered_state.last_preserved_timestamp 598 if current_clock < self._last_preserved_timestamp: 599 # Time seems to have reversed itself. In addition to this warning, we'll 600 # min() saved checkpoint timestamps with the current time to ensure that 601 # old checkpoints don't get deleted accidentally. 602 logging.warning( 603 ("time.time() returned a value %f seconds behind the last " 604 "preserved checkpoint timestamp.") 605 % (self._last_preserved_timestamp - current_clock,)) 606 self._last_preserved_timestamp = current_clock 607 all_timestamps = recovered_state.all_model_checkpoint_timestamps 608 all_paths = recovered_state.all_model_checkpoint_paths 609 del recovered_state # Uses modified values from now on 610 if not all_timestamps: 611 all_timestamps = [self._last_preserved_timestamp] * len(all_paths) 612 613 for filename, timestamp in zip(all_paths, all_timestamps): 614 timestamp = min(timestamp, current_clock) 615 if timestamp > self._last_preserved_timestamp: 616 self._maybe_delete[filename] = timestamp 617 618 @property 619 def directory(self): 620 return self._directory 621 622 @property 623 def latest_checkpoint(self): 624 """The prefix of the most recent checkpoint in `directory`. 625 626 Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is 627 the constructor argument to `CheckpointManager`. 628 629 Suitable for passing to `tf.train.Checkpoint.restore` to resume training. 630 631 Returns: 632 The checkpoint prefix. If there are no checkpoints, returns `None`. 633 """ 634 return self._latest_checkpoint 635 636 @property 637 def checkpoints(self): 638 """A list of managed checkpoints. 639 640 Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not 641 show up in this list (to avoid ever-growing filename lists). 642 643 Returns: 644 A list of filenames, sorted from oldest to newest. 645 """ 646 return list(self._maybe_delete.keys()) 647 648 def _sweep(self): 649 """Deletes or preserves managed checkpoints.""" 650 if not self._max_to_keep: 651 # Does not update self._last_preserved_timestamp, since everything is kept 652 # in the active set. 653 return 654 while len(self._maybe_delete) > self._max_to_keep: 655 filename, timestamp = self._maybe_delete.popitem(last=False) 656 # Even if we're keeping this checkpoint due to 657 # keep_checkpoint_every_n_hours, we won't reference it to avoid 658 # infinitely-growing CheckpointState protos. 659 if (self._keep_checkpoint_every_n_hours 660 and (timestamp - self._keep_checkpoint_every_n_hours * 3600. 
               >= self._last_preserved_timestamp)):
        self._last_preserved_timestamp = timestamp
        continue
      _delete_file_if_exists(filename + ".index")
      _delete_file_if_exists(filename + ".data-?????-of-?????")

  def _record_state(self):
    """Saves the `CheckpointManager`'s state in `directory`."""
    filenames, timestamps = zip(*self._maybe_delete.items())
    update_checkpoint_state_internal(
        self._directory,
        model_checkpoint_path=self.latest_checkpoint,
        all_model_checkpoint_paths=filenames,
        all_model_checkpoint_timestamps=timestamps,
        last_preserved_timestamp=self._last_preserved_timestamp,
        save_relative_paths=True)

  @property
  def _prefix(self):
    """A common prefix for all checkpoints saved with this manager.

    For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`,
    `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be
    numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each
    checkpoint has several associated files
    (e.g. `"/tmp/tf-model/ckpt-2.index"`).

    Returns:
      A string prefix.
    """
    return self._checkpoint_prefix

  def save(self, checkpoint_number=None):
    """Creates a new checkpoint and manages it.

    Args:
      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
        `Tensor`, used to number the checkpoint. If `None` (default),
        checkpoints are numbered using `checkpoint.save_counter`. Even if
        `checkpoint_number` is provided, `save_counter` is still incremented. A
        user-provided `checkpoint_number` is not incremented even if it is a
        `Variable`.

    Returns:
      The path to the new checkpoint. It is also recorded in the `checkpoints`
      and `latest_checkpoint` properties.
    """
    # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
    # slightly with a custom numbering option.
    if context.executing_eagerly():
      save_counter = self._checkpoint.save_counter
      save_counter.assign_add(1)
      session = None
    else:
      session = ops.get_default_session()

      def _initializing_creator(next_creator, **kwargs):
        """Initialize the save counter if it has been newly created."""
        v = next_creator(**kwargs)
        session.run(v.initializer)
        return v

      with variable_scope.variable_creator_scope(_initializing_creator):
        save_counter = self._checkpoint.save_counter
      if self._save_counter_assign is None:
        self._save_counter_assign = save_counter.assign_add(1, read_value=False)
      session.run(self._save_counter_assign)
    if checkpoint_number is None:
      checkpoint_number = save_counter
    if not isinstance(checkpoint_number, compat.integral_types):
      checkpoint_number = training_util.global_step(
          sess=session, global_step_tensor=checkpoint_number)
    prefix = "%s-%d" % (self._prefix, checkpoint_number)
    save_path = self._checkpoint.write(prefix)
    timestamp = time.time()
    # If this is an overwritten checkpoint we were previously tracking, delete
    # and reinsert it to make sure it goes to the end of the queue.
    if save_path in self._maybe_delete:
      del self._maybe_delete[save_path]
    self._maybe_delete[save_path] = timestamp
    self._latest_checkpoint = save_path
    # Before deleting anything we update the Checkpoint proto with the new
    # checkpoint. We'll go back and correct it after cleaning up old files, but
    # a preemption while deleting will be more likely to see the new checkpoint
    # this way.
    self._record_state()
    self._sweep()
    # Write out the Checkpoint proto a second time, now without the deleted
    # checkpoints.
    self._record_state()
    return save_path