# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=invalid-name
"""Save and restore variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os.path
import re
import time

from google.protobuf import text_format

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _evaluate(tensor):
  """Returns the numpy value of a tensor."""
  if context.executing_eagerly():
    return tensor.numpy()
  return ops.get_default_session().run(tensor)


def _GetCheckpointFilename(save_dir, latest_filename):
  """Returns a filename for storing the CheckpointState.

  Args:
    save_dir: The directory for saving and restoring checkpoints.
    latest_filename: Name of the file in 'save_dir' that is used
      to store the CheckpointState.

  Returns:
    The path of the file that contains the CheckpointState proto.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  return os.path.join(save_dir, latest_filename)


@tf_export(v1=["train.generate_checkpoint_state_proto"])
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None,
                                    all_model_checkpoint_timestamps=None,
                                    last_preserved_timestamp=None):
  """Generates a checkpoint state proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest. If this is a non-empty list,
      the last element must be equal to model_checkpoint_path. These paths
      are also saved in the CheckpointState proto.
    all_model_checkpoint_timestamps: A list of floats, indicating the number of
      seconds since the Epoch when each checkpoint was generated.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).

  Returns:
    CheckpointState proto with model_checkpoint_path and
    all_model_checkpoint_paths updated to either absolute paths or
    relative paths to the current save_dir.

  Raises:
    ValueError: If `all_model_checkpoint_timestamps` was provided but its
      length does not match `all_model_checkpoint_paths`.
  """
  if all_model_checkpoint_paths is None:
    all_model_checkpoint_paths = []

  if (not all_model_checkpoint_paths or
      all_model_checkpoint_paths[-1] != model_checkpoint_path):
    logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
                 model_checkpoint_path)
    all_model_checkpoint_paths.append(model_checkpoint_path)

  if (all_model_checkpoint_timestamps
      and (len(all_model_checkpoint_timestamps)
           != len(all_model_checkpoint_paths))):
    raise ValueError(
        ("Checkpoint timestamps, if provided, must match checkpoint paths (got "
         "paths %s and timestamps %s)")
        % (all_model_checkpoint_paths, all_model_checkpoint_timestamps))

  # Relative paths need to be rewritten to be relative to the "save_dir"
  # if model_checkpoint_path already contains "save_dir".
  if not os.path.isabs(save_dir):
    if not os.path.isabs(model_checkpoint_path):
      model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
    for i, p in enumerate(all_model_checkpoint_paths):
      if not os.path.isabs(p):
        all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)

  coord_checkpoint_proto = CheckpointState(
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
      last_preserved_timestamp=last_preserved_timestamp)

  return coord_checkpoint_proto
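

# A minimal sketch (not part of the TensorFlow API surface) of the
# relative-path rewrite performed above. `_demo_generate_checkpoint_state` is
# a hypothetical helper name; it is never called by library code and assumes
# POSIX-style path separators.
def _demo_generate_checkpoint_state():
  ckpt = generate_checkpoint_state_proto(
      "train_dir",
      "train_dir/model.ckpt-100",
      all_model_checkpoint_paths=["train_dir/model.ckpt-50",
                                  "train_dir/model.ckpt-100"])
  # Because both save_dir and the checkpoint paths are relative, the proto
  # stores the paths relative to save_dir.
  assert ckpt.model_checkpoint_path == "model.ckpt-100"
  assert list(ckpt.all_model_checkpoint_paths) == [
      "model.ckpt-50", "model.ckpt-100"]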


@deprecation.deprecated(
    date=None,
    instructions=("Use `tf.train.CheckpointManager` to manage checkpoints "
                  "rather than manually editing the Checkpoint proto."))
@tf_export(v1=["train.update_checkpoint_state"])
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None,
                            all_model_checkpoint_timestamps=None,
                            last_preserved_timestamp=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest. If this is a non-empty list,
      the last element must be equal to model_checkpoint_path. These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file. Defaults to
      'checkpoint'.
    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
      seconds since the Epoch) indicating when the checkpoints in
      `all_model_checkpoint_paths` were created.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).

  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
  """
  update_checkpoint_state_internal(
      save_dir=save_dir,
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      latest_filename=latest_filename,
      save_relative_paths=False,
      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
      last_preserved_timestamp=last_preserved_timestamp)


@tf_export("__internal__.train.update_checkpoint_state", v1=[])
def update_checkpoint_state_internal(save_dir,
                                     model_checkpoint_path,
                                     all_model_checkpoint_paths=None,
                                     latest_filename=None,
                                     save_relative_paths=False,
                                     all_model_checkpoint_timestamps=None,
                                     last_preserved_timestamp=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest. If this is a non-empty list,
      the last element must be equal to model_checkpoint_path. These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file. Defaults to
      'checkpoint'.
    save_relative_paths: If `True`, will write relative paths to the checkpoint
      state file.
    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
      seconds since the Epoch) indicating when the checkpoints in
      `all_model_checkpoint_paths` were created.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).

  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
  """
  # Writes the "checkpoint" file for the coordinator for later restoration.
  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
  if save_relative_paths:
    if os.path.isabs(model_checkpoint_path):
      rel_model_checkpoint_path = os.path.relpath(
          model_checkpoint_path, save_dir)
    else:
      rel_model_checkpoint_path = model_checkpoint_path
    rel_all_model_checkpoint_paths = []
    for p in all_model_checkpoint_paths:
      if os.path.isabs(p):
        rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
      else:
        rel_all_model_checkpoint_paths.append(p)
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        rel_model_checkpoint_path,
        all_model_checkpoint_paths=rel_all_model_checkpoint_paths,
        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
        last_preserved_timestamp=last_preserved_timestamp)
  else:
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths,
        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
        last_preserved_timestamp=last_preserved_timestamp)

  if coord_checkpoint_filename == ckpt.model_checkpoint_path:
    raise RuntimeError("Save path '%s' conflicts with path used for "
                       "checkpoint state. Please use a different save path." %
                       model_checkpoint_path)

  # Preventing potential read/write race condition by *atomically* writing to
  # a file.
  file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                      text_format.MessageToString(ckpt))
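

# A minimal round-trip sketch (illustrative only, never called by library
# code): write a checkpoint state file, then read it back with
# `get_checkpoint_state` below. Assumes a writable temporary directory; the
# function name is hypothetical.
def _demo_update_and_read_checkpoint_state():
  import tempfile

  save_dir = tempfile.mkdtemp()
  update_checkpoint_state_internal(
      save_dir,
      os.path.join(save_dir, "model.ckpt-10"),
      all_model_checkpoint_paths=[os.path.join(save_dir, "model.ckpt-5"),
                                  os.path.join(save_dir, "model.ckpt-10")],
      save_relative_paths=True)
  # get_checkpoint_state re-prepends save_dir to the relative paths it reads.
  ckpt = get_checkpoint_state(save_dir)
  assert ckpt.model_checkpoint_path == os.path.join(save_dir, "model.ckpt-10")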


@tf_export("train.get_checkpoint_state")
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file. Defaults to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if file_io.file_exists(coord_checkpoint_filename):
      file_content = file_io.read_file_to_string(
          coord_checkpoint_filename)
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      if not ckpt.model_checkpoint_path:
        raise ValueError("Invalid checkpoint state loaded from "
                         + checkpoint_dir)
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(ckpt.model_checkpoint_path):
        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                  ckpt.model_checkpoint_path)
      for i, p in enumerate(ckpt.all_model_checkpoint_paths):
        if not os.path.isabs(p):
          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except errors.OpError as e:
    # It's ok if the file cannot be read.
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  except text_format.ParseError as e:
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  return ckpt


def _prefix_to_checkpoint_path(prefix, format_version):
  """Returns the pathname of a checkpoint file, given the checkpoint prefix.

  For a V1 checkpoint, simply returns the prefix itself (the data file). For
  V2, returns the pathname to the index file.

  Args:
    prefix: a string, the prefix of a checkpoint.
    format_version: the checkpoint format version that corresponds to the
      prefix.

  Returns:
    The pathname of a checkpoint file, taking into account the checkpoint
    format version.
  """
  if format_version == saver_pb2.SaverDef.V2:
    return prefix + ".index"  # The index file identifies a checkpoint.
  return prefix  # Just the data file.
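

# A small sketch (illustrative only, never called by library code) of the
# V1/V2 naming difference handled above: a V2 checkpoint is identified by its
# ".index" file, while a V1 prefix already names the data file.
def _demo_prefix_to_checkpoint_path():
  assert _prefix_to_checkpoint_path(
      "/tmp/model.ckpt-1", saver_pb2.SaverDef.V2) == "/tmp/model.ckpt-1.index"
  assert _prefix_to_checkpoint_path(
      "/tmp/model.ckpt-1", saver_pb2.SaverDef.V1) == "/tmp/model.ckpt-1"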


@tf_export("train.latest_checkpoint")
def latest_checkpoint(checkpoint_dir, latest_filename=None):
  """Finds the filename of the latest saved checkpoint file.

  Gets the checkpoint state given the provided checkpoint_dir and looks for a
  corresponding TensorFlow 2 (preferred) or TensorFlow 1.x checkpoint path.
  The latest_filename argument is only applicable if you are saving checkpoints
  using `v1.train.Saver.save`.

  See the [Training Checkpoints
  Guide](https://www.tensorflow.org/guide/checkpoint) for more details and
  examples.

  Args:
    checkpoint_dir: Directory where the variables were saved.
    latest_filename: Optional name for the protocol buffer file that
      contains the list of most recent checkpoint filenames.
      See the corresponding argument to `v1.train.Saver.save`.

  Returns:
    The full path to the latest checkpoint or `None` if no checkpoint was
    found.
  """
  # Pick the latest checkpoint based on checkpoint state.
  ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
  if ckpt and ckpt.model_checkpoint_path:
    # Look for either a V2 path or a V1 path, with priority for V2.
    v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V2)
    v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V1)
    if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
        v1_path):
      return ckpt.model_checkpoint_path
    else:
      logging.error("Couldn't match files for checkpoint %s",
                    ckpt.model_checkpoint_path)
  return None


def checkpoint_exists_internal(checkpoint_prefix):
  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.

  This is an internal function to check if a checkpoint exists,
  since it takes into account the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.

  Returns:
    A bool, true if a checkpoint referred to by `checkpoint_prefix` exists.
  """
  pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                        saver_pb2.SaverDef.V2)
  if file_io.get_matching_files(pathname):
    return True
  elif file_io.get_matching_files(checkpoint_prefix):
    return True
  else:
    return False


@deprecation.deprecated(
    date=None,
    instructions="Use standard file APIs to check for files with this prefix.")
@tf_export(v1=["train.checkpoint_exists"])
def checkpoint_exists(checkpoint_prefix):
  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.

  This is the recommended way to check if a checkpoint exists, since it takes
  into account the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.

  Returns:
    A bool, true if a checkpoint referred to by `checkpoint_prefix` exists.
  """
  return checkpoint_exists_internal(checkpoint_prefix)
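

# An illustrative sketch (never called by library code): a V2 checkpoint's
# existence is established by its ".index" file, so creating a dummy index
# file is enough for the prefix to be reported as existing. Assumes a
# writable temporary directory.
def _demo_checkpoint_exists():
  import tempfile

  prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt-1")
  assert not checkpoint_exists_internal(prefix)
  file_io.write_string_to_file(prefix + ".index", "")
  assert checkpoint_exists_internal(prefix)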


@deprecation.deprecated(
    date=None,
    instructions="Use standard file utilities to get mtimes.")
@tf_export(v1=["train.get_checkpoint_mtimes"])
def get_checkpoint_mtimes(checkpoint_prefixes):
  """Returns the mtimes (modification timestamps) of the checkpoints.

  Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files
  exist, collect their mtime. Both V2 and V1 checkpoints are considered, in
  that priority.

  This is the recommended way to get the mtimes, since it takes into account
  the naming difference between V1 and V2 formats.

  Note: If not all checkpoints exist, the length of the returned mtimes list
  will be smaller than the length of `checkpoint_prefixes` list, so mapping
  checkpoints to corresponding mtimes will not be possible.

  Args:
    checkpoint_prefixes: a list of checkpoint paths, typically the results of
      `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
      sharded/non-sharded or V1/V2.

  Returns:
    A list of mtimes (in seconds since the Epoch) of the found checkpoints.
  """
  mtimes = []

  def match_maybe_append(pathname):
    fnames = file_io.get_matching_files(pathname)
    if fnames:
      mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
      return True
    return False

  for checkpoint_prefix in checkpoint_prefixes:
    # Tries V2's metadata file first.
    pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                          saver_pb2.SaverDef.V2)
    if match_maybe_append(pathname):
      continue
    # Otherwise, tries V1, where the prefix is the complete pathname.
    match_maybe_append(checkpoint_prefix)

  return mtimes


@deprecation.deprecated(
    date=None,
    instructions="Use standard file APIs to delete files with this prefix.")
@tf_export(v1=["train.remove_checkpoint"])
def remove_checkpoint(checkpoint_prefix,
                      checkpoint_format_version=saver_pb2.SaverDef.V2,
                      meta_graph_suffix="meta"):
  """Removes a checkpoint given by `checkpoint_prefix`.

  Args:
    checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the
      result of `Saver.save()` or that of `tf.train.latest_checkpoint()`,
      regardless of sharded/non-sharded or V1/V2.
    checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
      `SaverDef.V2`.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
  """
  _delete_file_if_exists(
      meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
  if checkpoint_format_version == saver_pb2.SaverDef.V2:
    # V2 has a metadata file and some data files.
    _delete_file_if_exists(checkpoint_prefix + ".index")
    _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
  else:
    # V1, Legacy. Exact match on the data file.
    _delete_file_if_exists(checkpoint_prefix)


def _delete_file_if_exists(filespec):
  """Deletes files matching `filespec`."""
  for pathname in file_io.get_matching_files(filespec):
    try:
      file_io.delete_file(pathname)
    except errors.NotFoundError:
      logging.warning(
          "Hit NotFoundError when deleting '%s', possibly because another "
          "process/thread is also deleting/moving the same file", pathname)


def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
  """Returns the meta graph filename.

  Args:
    checkpoint_filename: Name of the checkpoint file.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.

  Returns:
    MetaGraph file name.
  """
  # If the checkpoint_filename is sharded, the checkpoint_filename could
  # be of format model.ckpt-step#-?????-of-shard#. For example,
  # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
  basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
  suffixed_filename = ".".join([basename, meta_graph_suffix])
  return suffixed_filename
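

# A short sketch (illustrative only, never called by library code) of the
# sharded-suffix stripping performed by `meta_graph_filename` above.
def _demo_meta_graph_filename():
  # Non-sharded: just append the suffix.
  assert meta_graph_filename("model.ckpt-123456") == "model.ckpt-123456.meta"
  # Sharded: the "-?????-of-00005" part is stripped first.
  assert meta_graph_filename(
      "model.ckpt-123456-?????-of-00005") == "model.ckpt-123456.meta"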


# TODO(allenl): Allow tf.keras.Model instances in the constructor directly?
@tf_export("train.CheckpointManager")
class CheckpointManager(object):
  """Manages multiple checkpoints by keeping some and deleting unneeded ones.

  Example usage:

  ```python
  import tensorflow as tf
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(
      checkpoint, directory="/tmp/model", max_to_keep=5)
  status = checkpoint.restore(manager.latest_checkpoint)
  while True:
    # train
    manager.save()
  ```

  `CheckpointManager` preserves its own state across instantiations (see the
  `__init__` documentation for details). Only one should be active in a
  particular directory at a time.
  """

  def __init__(self,
               checkpoint,
               directory,
               max_to_keep,
               keep_checkpoint_every_n_hours=None,
               checkpoint_name="ckpt",
               step_counter=None,
               checkpoint_interval=None,
               init_fn=None):
    """Configure a `CheckpointManager` for use in `directory`.

    If a `CheckpointManager` was previously used in `directory`, its
    state will be restored. This includes the list of managed checkpoints and
    the timestamp bookkeeping necessary to support
    `keep_checkpoint_every_n_hours`. The behavior of the new `CheckpointManager`
    will be the same as the previous `CheckpointManager`, including cleaning up
    existing checkpoints if appropriate.

    Checkpoints are only considered for deletion just after a new checkpoint has
    been added. At that point, `max_to_keep` checkpoints will remain in an
    "active set". Once a checkpoint is preserved by
    `keep_checkpoint_every_n_hours` it will not be deleted by this
    `CheckpointManager` or any future `CheckpointManager` instantiated in
    `directory` (regardless of the new setting of
    `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the
    active set may be deleted by this `CheckpointManager` or a future
    `CheckpointManager` instantiated in `directory` (subject to its
    `max_to_keep` and `keep_checkpoint_every_n_hours` settings).

    `CheckpointManager` can also be used to initialize the model if
    there are no checkpoints to restore in `directory`. An example usage is:

    >>> import tempfile

    >>> tmp_dir = tempfile.mkdtemp()
    >>> checkpoint = tf.train.Checkpoint()
    >>> init_path = checkpoint.save(os.path.join(tmp_dir, 'init'))

    >>> def init_fn():
    ...   # Partially restore the checkpoint from `init_path`.
    ...   checkpoint.restore(init_path)

    >>> manager = tf.train.CheckpointManager(
    ...     checkpoint,
    ...     directory=os.path.join(tmp_dir, 'ckpt'),
    ...     max_to_keep=None,
    ...     init_fn=init_fn)
    >>> # `restore_or_initialize` will call `init_fn` if there is no existing
    >>> # checkpoint in `directory`.
    >>> manager.restore_or_initialize()

    Args:
      checkpoint: The `tf.train.Checkpoint` instance to save and manage
        checkpoints for.
      directory: The path to a directory in which to write checkpoints. A
        special file named "checkpoint" is also written to this directory (in a
        human-readable text format) which contains the state of the
        `CheckpointManager`.
      max_to_keep: An integer, the number of checkpoints to keep. Unless
        preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
        deleted from the active set, oldest first, until only `max_to_keep`
        checkpoints remain. If `None`, no checkpoints are deleted and everything
        stays in the active set. Note that `max_to_keep=None` will keep all
        checkpoint paths in memory and in the checkpoint state protocol buffer
        on disk.
      keep_checkpoint_every_n_hours: Upon removal from the active set, a
        checkpoint will be preserved if it has been at least
        `keep_checkpoint_every_n_hours` since the last preserved checkpoint. The
        default setting of `None` does not preserve any checkpoints in this way.
      checkpoint_name: Custom name for the checkpoint file.
      step_counter: A `tf.Variable` instance for checking the current step
        counter value, in case users want to save checkpoints every N steps.
      checkpoint_interval: An integer, the minimum step interval between two
        checkpoints.
      init_fn: Callable. A function to do customized initialization if no
        checkpoints are in the directory.

    Raises:
      ValueError: If `max_to_keep` is not a positive integer.
    """
    self._checkpoint = checkpoint
    self._save_counter_assign = None
    if max_to_keep is not None and max_to_keep <= 0:
      raise ValueError(
          ("Expected a positive integer or `None` for `max_to_keep`, "
           "got %d.")
          % (max_to_keep,))
    self._max_to_keep = max_to_keep
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    self._directory = directory
    self._checkpoint_prefix = os.path.join(directory, checkpoint_name)
    self._init_fn = init_fn

    if checkpoint_interval is not None:
      if step_counter is None:
        raise ValueError("`step_counter` should be passed if "
                         "`checkpoint_interval` is not None.")
      self._last_checkpoint_step = None
      self._step_counter = step_counter
    self._checkpoint_interval = checkpoint_interval

    recovered_state = get_checkpoint_state(directory)
    current_clock = time.time()
    self._maybe_delete = collections.OrderedDict()
    if recovered_state is None:
      self._latest_checkpoint = None
      # Set the clock back slightly to avoid race conditions when quickly
      # re-creating a CheckpointManager.
      self._last_preserved_timestamp = current_clock - 1.
    else:
      self._latest_checkpoint = recovered_state.model_checkpoint_path
      self._last_preserved_timestamp = recovered_state.last_preserved_timestamp
      if current_clock < self._last_preserved_timestamp:
        # Time seems to have reversed itself. In addition to this warning, we'll
        # min() saved checkpoint timestamps with the current time to ensure that
        # old checkpoints don't get deleted accidentally.
        logging.warning(
            ("time.time() returned a value %f seconds behind the last "
             "preserved checkpoint timestamp.")
            % (self._last_preserved_timestamp - current_clock,))
        self._last_preserved_timestamp = current_clock
      all_timestamps = recovered_state.all_model_checkpoint_timestamps
      all_paths = recovered_state.all_model_checkpoint_paths
      del recovered_state  # Uses modified values from now on.
      if not all_timestamps:
        all_timestamps = [self._last_preserved_timestamp] * len(all_paths)

      for filename, timestamp in zip(all_paths, all_timestamps):
        timestamp = min(timestamp, current_clock)
        if timestamp > self._last_preserved_timestamp:
          self._maybe_delete[filename] = timestamp

  @property
  def directory(self):
    """The directory in which checkpoints will be written."""
    return self._directory

  @property
  def checkpoint_interval(self):
    """The minimum step interval between two checkpoints, or `None`."""
    return self._checkpoint_interval

  @property
  def latest_checkpoint(self):
    """The prefix of the most recent checkpoint in `directory`.

    Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is
    the constructor argument to `CheckpointManager`.

    Suitable for passing to `tf.train.Checkpoint.restore` to resume training.

    Returns:
      The checkpoint prefix. If there are no checkpoints, returns `None`.
    """
    return self._latest_checkpoint

  @property
  def checkpoints(self):
    """A list of managed checkpoints.

    Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not
    show up in this list (to avoid ever-growing filename lists).

    Returns:
      A list of filenames, sorted from oldest to newest.
    """
    return list(self._maybe_delete.keys())

  def _sweep(self):
    """Deletes or preserves managed checkpoints."""
    if not self._max_to_keep:
      # Does not update self._last_preserved_timestamp, since everything is
      # kept in the active set.
      return
    while len(self._maybe_delete) > self._max_to_keep:
      filename, timestamp = self._maybe_delete.popitem(last=False)
      # Even if we're keeping this checkpoint due to
      # keep_checkpoint_every_n_hours, we won't reference it to avoid
      # infinitely-growing CheckpointState protos.
      if (self._keep_checkpoint_every_n_hours
          and (timestamp - self._keep_checkpoint_every_n_hours * 3600.
               >= self._last_preserved_timestamp)):
        self._last_preserved_timestamp = timestamp
        continue
      _delete_file_if_exists(filename + ".index")
      _delete_file_if_exists(filename + ".data-?????-of-?????")
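
  # Worked example of the `_sweep` bookkeeping above (illustrative comment
  # only): with `keep_checkpoint_every_n_hours=1` and last preserved timestamp
  # T, a checkpoint popped from the active set with timestamp T + 3700s
  # satisfies (T + 3700) - 3600 >= T, so its files stay on disk, it becomes
  # the new preservation baseline, and it merely stops being tracked. A
  # checkpoint with timestamp T + 1800s fails the test, and its ".index" and
  # data files are deleted.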

  def _record_state(self):
    """Saves the `CheckpointManager`'s state in `directory`."""
    filenames, timestamps = zip(*self._maybe_delete.items())
    update_checkpoint_state_internal(
        self._directory,
        model_checkpoint_path=self.latest_checkpoint,
        all_model_checkpoint_paths=filenames,
        all_model_checkpoint_timestamps=timestamps,
        last_preserved_timestamp=self._last_preserved_timestamp,
        save_relative_paths=True)

  @property
  def _prefix(self):
    """A common prefix for all checkpoints saved with this manager.

    For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`,
    `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be
    numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each
    checkpoint has several associated files
    (e.g. `"/tmp/tf-model/ckpt-2.index"`).

    Returns:
      A string prefix.
    """
    return self._checkpoint_prefix

  @property
  def checkpoint(self):
    """Returns the `tf.train.Checkpoint` object."""
    return self._checkpoint

  def save(self, checkpoint_number=None, check_interval=True, options=None):
    """Creates a new checkpoint and manages it.

    Args:
      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
        `Tensor`, used to number the checkpoint. If `None` (default),
        checkpoints are numbered using `checkpoint.save_counter`. Even if
        `checkpoint_number` is provided, `save_counter` is still incremented. A
        user-provided `checkpoint_number` is not incremented even if it is a
        `Variable`.
      check_interval: An optional boolean. The argument is only effective when
        `checkpoint_interval` is passed into the manager. If `True`, the manager
        will only save the checkpoint if the interval between checkpoints is
        larger than `checkpoint_interval`. Otherwise it will always save the
        checkpoint unless a checkpoint has already been saved for the current
        step.
      options: Optional `tf.train.CheckpointOptions` object. This argument only
        works with TF2 checkpoint objects. For example, options =
        tf.train.CheckpointOptions(experimental_io_device='/job:localhost')

    Returns:
      The path to the new checkpoint. It is also recorded in the `checkpoints`
      and `latest_checkpoint` properties. `None` if no checkpoint is saved.
    """
    if self._checkpoint_interval is not None:
      current_step = _evaluate(self._step_counter)
      if self._last_checkpoint_step is not None:
        if current_step == self._last_checkpoint_step:
          return None
        if check_interval and current_step < (
            self._last_checkpoint_step + self._checkpoint_interval):
          return None
      self._last_checkpoint_step = current_step

    # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
    # slightly with a custom numbering option.
    if context.executing_eagerly():
      save_counter = self._checkpoint.save_counter
      save_counter.assign_add(1)
      session = None
    else:
      session = ops.get_default_session()

      def _initializing_creator(next_creator, **kwargs):
        """Initialize the save counter if it has been newly created."""
        v = next_creator(**kwargs)
        session.run(v.initializer)
        return v

      with variable_scope.variable_creator_scope(_initializing_creator):
        save_counter = self._checkpoint.save_counter
      if self._save_counter_assign is None:
        self._save_counter_assign = save_counter.assign_add(1, read_value=False)
      session.run(self._save_counter_assign)
    if checkpoint_number is None:
      checkpoint_number = save_counter
    if not isinstance(checkpoint_number, compat.integral_types):
      checkpoint_number = training_util.global_step(
          sess=session, global_step_tensor=checkpoint_number)
    prefix = "%s-%d" % (self._prefix, checkpoint_number)
    if options is None:
      save_path = self._checkpoint.write(prefix)
    else:
      save_path = self._checkpoint.write(prefix, options=options)
    timestamp = time.time()
    # If this is an overwritten checkpoint we were previously tracking, delete
    # and reinsert it to make sure it goes to the end of the queue.
    if save_path in self._maybe_delete:
      del self._maybe_delete[save_path]
    self._maybe_delete[save_path] = timestamp
    self._latest_checkpoint = save_path
    # Before deleting anything we update the Checkpoint proto with the new
    # checkpoint. We'll go back and correct it after cleaning up old files, but
    # a preemption while deleting will be more likely to see the new checkpoint
    # this way.
    self._record_state()
    self._sweep()
    # Write out the Checkpoint proto a second time, now without the deleted
    # checkpoints.
    self._record_state()
    return save_path
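
  # Illustrative note on `save` numbering (comment only): with the default
  # `checkpoint_number=None`, successive calls write "<directory>/ckpt-1",
  # "<directory>/ckpt-2", and so on; `manager.save(checkpoint_number=1000)`
  # instead writes "<directory>/ckpt-1000" while still incrementing
  # `checkpoint.save_counter`.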

  def restore_or_initialize(self):
    """Restore items in `checkpoint` from the latest checkpoint file.

    This method will first try to restore from the most recent checkpoint in
    `directory`. If no checkpoints exist in `directory`, and `init_fn` is
    specified, this method will call `init_fn` to do customized
    initialization. This can be used to support initialization from pretrained
    models.

    Note that unlike `tf.train.Checkpoint.restore()`, this method doesn't
    return a load status object that users can run assertions on
    (e.g. assert_consumed()). Thus to run assertions, users should use the
    `tf.train.Checkpoint.restore()` method directly.

    Returns:
      The restored checkpoint path if the latest checkpoint is found and
      restored. Otherwise None.
    """
    if self._latest_checkpoint is not None:
      self._checkpoint.restore(self._latest_checkpoint)
      if self._checkpoint_interval is not None:
        self._last_checkpoint_step = _evaluate(self._step_counter)
      return self._latest_checkpoint

    if self._init_fn is not None:
      self._init_fn()
    return None
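

# An end-to-end sketch (illustrative only, never called by library code) of
# the save/restore cycle `CheckpointManager` implements. Assumes TF2 eager
# execution and a writable temporary directory; the function name is
# hypothetical.
def _demo_checkpoint_manager():
  import tempfile

  import tensorflow as tf

  step = tf.Variable(0, dtype=tf.int64)
  ckpt = tf.train.Checkpoint(step=step)
  manager = tf.train.CheckpointManager(
      ckpt, directory=tempfile.mkdtemp(), max_to_keep=2)
  # No checkpoint exists yet (and no init_fn is set), so nothing is restored.
  manager.restore_or_initialize()
  for _ in range(3):
    step.assign_add(1)
    manager.save()
  # With max_to_keep=2, only the two most recent checkpoints remain in the
  # active set; the oldest one's files were deleted by _sweep.
  assert len(manager.checkpoints) == 2
  assert manager.latest_checkpoint == manager.checkpoints[-1]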