1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16# pylint: disable=invalid-name 17"""Save and restore variables. 18 19Symbols in this file are deprecated. See replacements in 20tensorflow/python/training/trackable and tensorflow/python/training/saving. 21""" 22from __future__ import absolute_import 23from __future__ import division 24from __future__ import print_function 25 26import collections 27import os.path 28import time 29import uuid 30 31import numpy as np 32from tensorflow.core.protobuf import meta_graph_pb2 33from tensorflow.core.protobuf import saver_pb2 34from tensorflow.core.protobuf import trackable_object_graph_pb2 35from tensorflow.python.client import session 36from tensorflow.python.eager import context 37from tensorflow.python.framework import constant_op 38from tensorflow.python.framework import device as pydev 39from tensorflow.python.framework import errors 40from tensorflow.python.framework import meta_graph 41from tensorflow.python.framework import ops 42from tensorflow.python.ops import array_ops 43from tensorflow.python.ops import control_flow_ops 44from tensorflow.python.ops import gen_io_ops 45from tensorflow.python.ops import io_ops 46from tensorflow.python.ops import string_ops 47from tensorflow.python.ops import variables 48from tensorflow.python.platform import gfile 49from tensorflow.python.platform import tf_logging as logging 50from tensorflow.python.training import checkpoint_management 51from tensorflow.python.training import py_checkpoint_reader 52from tensorflow.python.training import training_util 53from tensorflow.python.training.saving import saveable_object 54from tensorflow.python.training.saving import saveable_object_util 55from tensorflow.python.training.tracking import base as trackable 56from tensorflow.python.util import compat 57from tensorflow.python.util.tf_export import tf_export 58 59# TODO(allenl): Remove these aliases once all users are migrated off. 60get_checkpoint_state = checkpoint_management.get_checkpoint_state 61update_checkpoint_state = checkpoint_management.update_checkpoint_state 62generate_checkpoint_state_proto = ( 63 checkpoint_management.generate_checkpoint_state_proto) 64latest_checkpoint = checkpoint_management.latest_checkpoint 65checkpoint_exists = checkpoint_management.checkpoint_exists 66get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes 67remove_checkpoint = checkpoint_management.remove_checkpoint 68 69 70class BaseSaverBuilder(object): 71 """Base class for Savers. 72 73 Can be extended to create different Ops. 74 """ 75 76 SaveSpec = saveable_object.SaveSpec 77 SaveableObject = saveable_object.SaveableObject 78 79 # Aliases for code which was moved but still has lots of users. 80 VariableSaveable = saveable_object_util.ReferenceVariableSaveable 81 ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable 82 83 def __init__(self, write_version=saver_pb2.SaverDef.V2): 84 self._write_version = write_version 85 86 def save_op(self, filename_tensor, saveables): 87 """Create an Op to save 'saveables'. 88 89 This is intended to be overridden by subclasses that want to generate 90 different Ops. 91 92 Args: 93 filename_tensor: String Tensor. 94 saveables: A list of BaseSaverBuilder.SaveableObject objects. 95 96 Returns: 97 An Operation that save the variables. 98 99 Raises: 100 RuntimeError: (implementation detail) if "self._write_version" is an 101 unexpected value. 102 """ 103 # pylint: disable=protected-access 104 tensor_names = [] 105 tensors = [] 106 tensor_slices = [] 107 for saveable in saveables: 108 for spec in saveable.specs: 109 tensor_names.append(spec.name) 110 tensors.append(spec.tensor) 111 tensor_slices.append(spec.slice_spec) 112 if self._write_version == saver_pb2.SaverDef.V1: 113 return io_ops._save( 114 filename=filename_tensor, 115 tensor_names=tensor_names, 116 tensors=tensors, 117 tensor_slices=tensor_slices) 118 elif self._write_version == saver_pb2.SaverDef.V2: 119 # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix 120 # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>". 121 return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices, 122 tensors) 123 else: 124 raise RuntimeError("Unexpected write_version: " + self._write_version) 125 126 def bulk_restore(self, filename_tensor, saveables, preferred_shard, 127 restore_sequentially): 128 """Restore all tensors contained in saveables. 129 130 By default, this issues separate calls to `restore_op` for each saveable. 131 Subclasses may override to load multiple saveables in a single call. 132 133 Args: 134 filename_tensor: String Tensor. 135 saveables: List of BaseSaverBuilder.SaveableObject objects. 136 preferred_shard: Int. Shard to open first when loading a sharded file. 137 restore_sequentially: Unused. Bool. If true, each restore is sequential. 138 139 Returns: 140 A list of Tensors resulting from reading 'saveable' from 141 'filename'. 142 143 """ 144 del restore_sequentially 145 all_tensors = [] 146 for saveable in saveables: 147 if saveable.device: 148 device = saveable_object_util.set_cpu0(saveable.device) 149 else: 150 device = None 151 with ops.device(device): 152 all_tensors.extend( 153 self.restore_op(filename_tensor, saveable, preferred_shard)) 154 return all_tensors 155 156 # pylint: disable=unused-argument 157 def restore_op(self, filename_tensor, saveable, preferred_shard): 158 """Create ops to restore 'saveable'. 159 160 This is intended to be overridden by subclasses that want to generate 161 different Ops. 162 163 Args: 164 filename_tensor: String Tensor. 165 saveable: A BaseSaverBuilder.SaveableObject object. 166 preferred_shard: Int. Shard to open first when loading a sharded file. 167 168 Returns: 169 A list of Tensors resulting from reading 'saveable' from 170 'filename'. 171 """ 172 # pylint: disable=protected-access 173 tensors = [] 174 for spec in saveable.specs: 175 tensors.append( 176 io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec], 177 [spec.dtype])[0]) 178 179 return tensors 180 181 # pylint: enable=unused-argument 182 183 def sharded_filename(self, filename_tensor, shard, num_shards): 184 """Append sharding information to a filename. 185 186 Args: 187 filename_tensor: A string tensor. 188 shard: Integer. The shard for the filename. 189 num_shards: An int Tensor for the number of shards. 190 191 Returns: 192 A string tensor. 193 """ 194 return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards) 195 196 def _AddSaveOps(self, filename_tensor, saveables): 197 """Add ops to save variables that are on the same shard. 198 199 Args: 200 filename_tensor: String Tensor. 201 saveables: A list of SaveableObject objects. 202 203 Returns: 204 A tensor with the filename used to save. 205 """ 206 save = self.save_op(filename_tensor, saveables) 207 return control_flow_ops.with_dependencies([save], filename_tensor) 208 209 def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device): 210 """Add ops to save the params per shard, for the V2 format. 211 212 Note that the sharded save procedure for the V2 format is different from 213 V1: there is a special "merge" step that merges the small metadata produced 214 from each device. 215 216 Args: 217 checkpoint_prefix: scalar String Tensor. Interpreted *NOT AS A FILENAME*, 218 but as a prefix of a V2 checkpoint; 219 per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as 220 returned by _GroupByDevices(). 221 222 Returns: 223 An op to save the variables, which, when evaluated, returns the prefix 224 "<user-fed prefix>" only and does not include the sharded spec suffix. 225 """ 226 # IMPLEMENTATION DETAILS: most clients should skip. 227 # 228 # Suffix for any well-formed "checkpoint_prefix", when sharded. 229 # Transformations: 230 # * Users pass in "save_path" in save() and restore(). Say "myckpt". 231 # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>. 232 # 233 # Example: 234 # During runtime, a temporary directory is first created, which contains 235 # files 236 # 237 # <train dir>/myckpt_temp/ 238 # part-?????-of-?????{.index, .data-00000-of-00001} 239 # 240 # Before .save() finishes, they will be (hopefully, atomically) renamed to 241 # 242 # <train dir>/ 243 # myckpt{.index, .data-?????-of-?????} 244 # 245 # Users only need to interact with the user-specified prefix, which is 246 # "<train dir>/myckpt" in this case. Save() and Restore() work with the 247 # prefix directly, instead of any physical pathname. (On failure and 248 # subsequent restore, an outdated and orphaned temporary directory can be 249 # safely removed.) 250 _SHARDED_SUFFIX = "_temp_%s/part" % uuid.uuid4().hex 251 tmp_checkpoint_prefix = string_ops.string_join( 252 [checkpoint_prefix, _SHARDED_SUFFIX]) 253 254 num_shards = len(per_device) 255 sharded_saves = [] 256 sharded_prefixes = [] 257 num_shards_tensor = constant_op.constant(num_shards, name="num_shards") 258 last_device = None 259 for shard, (device, saveables) in enumerate(per_device): 260 last_device = device 261 with ops.device(saveable_object_util.set_cpu0(device)): 262 sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard, 263 num_shards_tensor) 264 sharded_prefixes.append(sharded_filename) 265 sharded_saves.append(self._AddSaveOps(sharded_filename, saveables)) 266 267 with ops.control_dependencies([x.op for x in sharded_saves]): 268 # Co-locates the merge step with the last device. 269 with ops.device(saveable_object_util.set_cpu0(last_device)): 270 # V2 format write path consists of a metadata merge step. Once merged, 271 # attempts to delete the temporary directory, "<user-fed prefix>_temp". 272 merge_step = gen_io_ops.merge_v2_checkpoints( 273 sharded_prefixes, checkpoint_prefix, delete_old_dirs=True) 274 with ops.control_dependencies([merge_step]): 275 # Returns the prefix "<user-fed prefix>" only. DOES NOT include the 276 # sharded spec suffix. 277 return array_ops.identity(checkpoint_prefix) 278 279 def _AddShardedSaveOps(self, filename_tensor, per_device): 280 """Add ops to save the params per shard. 281 282 Args: 283 filename_tensor: a scalar String Tensor. 284 per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as 285 returned by _GroupByDevices(). 286 287 Returns: 288 An op to save the variables. 289 """ 290 if self._write_version == saver_pb2.SaverDef.V2: 291 return self._AddShardedSaveOpsForV2(filename_tensor, per_device) 292 293 num_shards = len(per_device) 294 sharded_saves = [] 295 num_shards_tensor = constant_op.constant(num_shards, name="num_shards") 296 for shard, (device, saveables) in enumerate(per_device): 297 with ops.device(device): 298 sharded_filename = self.sharded_filename(filename_tensor, shard, 299 num_shards_tensor) 300 sharded_saves.append(self._AddSaveOps(sharded_filename, saveables)) 301 # Return the sharded name for the save path. 302 with ops.control_dependencies([x.op for x in sharded_saves]): 303 return gen_io_ops.sharded_filespec(filename_tensor, num_shards_tensor) 304 305 def _AddRestoreOps(self, 306 filename_tensor, 307 saveables, 308 restore_sequentially, 309 reshape, 310 preferred_shard=-1, 311 name="restore_all"): 312 """Add operations to restore saveables. 313 314 Args: 315 filename_tensor: Tensor for the path of the file to load. 316 saveables: A list of SaveableObject objects. 317 restore_sequentially: True if we want to restore variables sequentially 318 within a shard. 319 reshape: True if we want to reshape loaded tensors to the shape of the 320 corresponding variable. 321 preferred_shard: Shard to open first when loading a sharded file. 322 name: Name for the returned op. 323 324 Returns: 325 An Operation that restores the variables. 326 """ 327 all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard, 328 restore_sequentially) 329 330 assign_ops = [] 331 idx = 0 332 # Load and optionally reshape on the CPU, as string tensors are not 333 # available on the GPU. 334 # TODO(touts): Re-enable restore on GPU when we can support annotating 335 # string tensors as "HostMemory" inputs. 336 for saveable in saveables: 337 shapes = None 338 if reshape: 339 # Compute the shapes, let the restore op decide if and how to do 340 # the reshape. 341 shapes = [] 342 for spec in saveable.specs: 343 v = spec.tensor 344 shape = v.get_shape() 345 if not shape.is_fully_defined(): 346 shape = array_ops.shape(v) 347 shapes.append(shape) 348 saveable_tensors = all_tensors[idx:idx + len(saveable.specs)] 349 idx += len(saveable.specs) 350 assign_ops.append(saveable.restore(saveable_tensors, shapes)) 351 352 # Create a Noop that has control dependencies from all the updates. 353 return control_flow_ops.group(*assign_ops, name=name) 354 355 def _AddShardedRestoreOps(self, filename_tensor, per_device, 356 restore_sequentially, reshape): 357 """Add Ops to restore variables from multiple devices. 358 359 Args: 360 filename_tensor: Tensor for the path of the file to load. 361 per_device: A list of (device, SaveableObject) pairs, as returned by 362 _GroupByDevices(). 363 restore_sequentially: True if we want to restore variables sequentially 364 within a shard. 365 reshape: True if we want to reshape loaded tensors to the shape of the 366 corresponding variable. 367 368 Returns: 369 An Operation that restores the variables. 370 """ 371 sharded_restores = [] 372 for shard, (device, saveables) in enumerate(per_device): 373 with ops.device(device): 374 sharded_restores.append( 375 self._AddRestoreOps( 376 filename_tensor, 377 saveables, 378 restore_sequentially, 379 reshape, 380 preferred_shard=shard, 381 name="restore_shard")) 382 return control_flow_ops.group(*sharded_restores, name="restore_all") 383 384 def _GroupByDevices(self, saveables): 385 """Group Variable tensor slices per device. 386 387 TODO(touts): Make sure that all the devices found are on different 388 job/replica/task/cpu|gpu. It would be bad if 2 were on the same device. 389 It can happen if the devices are unspecified. 390 391 Args: 392 saveables: A list of BaseSaverBuilder.SaveableObject objects. 393 394 Returns: 395 A list of tuples: (device_name, BaseSaverBuilder.SaveableObject) tuples. 396 The list is sorted by ascending device_name. 397 398 Raises: 399 ValueError: If the tensors of a saveable are on different devices. 400 """ 401 per_device = collections.defaultdict(lambda: []) 402 for saveable in saveables: 403 canonical_device = set( 404 pydev.canonical_name(spec.device) for spec in saveable.specs) 405 if len(canonical_device) != 1: 406 raise ValueError("All tensors of a saveable object must be " 407 "on the same device: %s" % saveable.name) 408 per_device[canonical_device.pop()].append(saveable) 409 return sorted(per_device.items(), key=lambda t: t[0]) 410 411 def build(self, 412 names_to_saveables, 413 reshape=False, 414 sharded=False, 415 max_to_keep=5, 416 keep_checkpoint_every_n_hours=10000.0, 417 name=None, 418 restore_sequentially=False, 419 filename="model"): 420 """Builds save/restore graph nodes or runs save/restore in eager mode. 421 422 Args: 423 names_to_saveables: A dictionary mapping name to a Variable or 424 SaveableObject. Each name will be associated with the corresponding 425 variable in the checkpoint. 426 reshape: If True, allow restoring parameters from a checkpoint that where 427 the parameters have a different shape. This is only needed when you try 428 to restore from a Dist-Belief checkpoint, and only some times. 429 sharded: If True, shard the checkpoints, one per device that has Variable 430 nodes. 431 max_to_keep: Maximum number of checkpoints to keep. As new checkpoints 432 are created, old ones are deleted. If None or 0, no checkpoints are 433 deleted from the filesystem but only the last one is kept in the 434 `checkpoint` file. Presently the number is only roughly enforced. For 435 example in case of restarts more than max_to_keep checkpoints may be 436 kept. 437 keep_checkpoint_every_n_hours: How often checkpoints should be kept. 438 Defaults to 10,000 hours. 439 name: String. Optional name to use as a prefix when adding operations. 440 restore_sequentially: A Bool, which if true, causes restore of different 441 variables to happen sequentially within each device. 442 filename: If known at graph construction time, filename used for variable 443 loading/saving. If None, then the default name "model" will be used. 444 445 Returns: 446 A SaverDef proto. 447 448 Raises: 449 TypeError: If 'names_to_saveables' is not a dictionary mapping string 450 keys to variable Tensors. 451 ValueError: If any of the keys or values in 'names_to_saveables' is not 452 unique. 453 """ 454 return self._build_internal( 455 names_to_saveables=names_to_saveables, 456 reshape=reshape, 457 sharded=sharded, 458 max_to_keep=max_to_keep, 459 keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, 460 name=name, 461 restore_sequentially=restore_sequentially, 462 filename=filename) 463 464 def _build_internal(self, 465 names_to_saveables, 466 reshape=False, 467 sharded=False, 468 max_to_keep=5, 469 keep_checkpoint_every_n_hours=10000.0, 470 name=None, 471 restore_sequentially=False, 472 filename="model", 473 build_save=True, 474 build_restore=True): 475 """build() with option to only perform save and restore.""" 476 if not context.executing_eagerly() and (not build_save or 477 not build_restore): 478 raise ValueError("save and restore operations need to be built together " 479 " when eager execution is not enabled.") 480 481 saveables = saveable_object_util.validate_and_slice_inputs( 482 names_to_saveables) 483 if max_to_keep is None: 484 max_to_keep = 0 485 486 with ops.name_scope(name, "save", 487 [saveable.op for saveable in saveables]) as name: 488 # Add a placeholder string tensor for the filename. 489 filename_tensor = array_ops.placeholder_with_default( 490 filename or "model", shape=(), name="filename") 491 # Keep the name "Const" for backwards compatibility. 492 filename_tensor = array_ops.placeholder_with_default( 493 filename_tensor, shape=(), name="Const") 494 495 # Add the save ops. 496 if sharded: 497 per_device = self._GroupByDevices(saveables) 498 if build_save: 499 save_tensor = self._AddShardedSaveOps(filename_tensor, per_device) 500 if build_restore: 501 restore_op = self._AddShardedRestoreOps(filename_tensor, per_device, 502 restore_sequentially, reshape) 503 else: 504 if build_save: 505 save_tensor = self._AddSaveOps(filename_tensor, saveables) 506 if build_restore: 507 restore_op = self._AddRestoreOps(filename_tensor, saveables, 508 restore_sequentially, reshape) 509 510 # In the following use case, it's possible to have restore_ops be called 511 # something else: 512 # - Build inference graph and export a meta_graph. 513 # - Import the inference meta_graph 514 # - Extend the inference graph to a train graph. 515 # - Export a new meta_graph. 516 # Now the second restore_op will be called "restore_all_1". 517 # As such, comment out the assert for now until we know whether supporting 518 # such usage model makes sense. 519 # 520 # assert restore_op.name.endswith("restore_all"), restore_op.name 521 if context.executing_eagerly(): 522 # Store the tensor values to the tensor_names. 523 save_tensor_name = save_tensor.numpy() if build_save else "" 524 return saver_pb2.SaverDef( 525 filename_tensor_name=filename_tensor.numpy(), 526 save_tensor_name=save_tensor_name, 527 restore_op_name="", 528 max_to_keep=max_to_keep, 529 sharded=sharded, 530 keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, 531 version=self._write_version) 532 else: 533 graph = ops.get_default_graph() 534 # Do some sanity checking on collections containing 535 # PartitionedVariables. If a saved collection has a PartitionedVariable, 536 # the GraphDef needs to include concat ops to get the value (or there'll 537 # be a lookup error on load). 538 check_collection_list = graph.get_all_collection_keys() 539 for collection_type in check_collection_list: 540 for element in graph.get_collection(collection_type): 541 if isinstance(element, variables.PartitionedVariable): 542 try: 543 graph.get_operation_by_name(element.name) 544 except KeyError: 545 # Create a concat op for this PartitionedVariable. The user may 546 # not need it, but we'll try looking it up on MetaGraph restore 547 # since it's in a collection. 548 element.as_tensor() 549 return saver_pb2.SaverDef( 550 filename_tensor_name=filename_tensor.name, 551 save_tensor_name=save_tensor.name, 552 restore_op_name=restore_op.name, 553 max_to_keep=max_to_keep, 554 sharded=sharded, 555 keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, 556 version=self._write_version) 557 558 559class BulkSaverBuilder(BaseSaverBuilder): 560 """SaverBuilder with support for bulk restoring multiple saveables.""" 561 562 def bulk_restore(self, filename_tensor, saveables, preferred_shard, 563 restore_sequentially): 564 565 # Ignored: bulk restore is internally sequential. 566 del restore_sequentially 567 restore_specs = [] 568 for saveable in saveables: 569 for spec in saveable.specs: 570 restore_specs.append((spec.name, spec.slice_spec, spec.dtype)) 571 572 names, slices, dtypes = zip(*restore_specs) 573 # Load all tensors onto CPU 0 for compatibility with existing code. 574 with ops.device("cpu:0"): 575 return io_ops.restore_v2(filename_tensor, names, slices, dtypes) 576 577 578def _get_saver_or_default(): 579 """Returns the saver from SAVERS collection, or creates a default one. 580 581 This method is used by other members of the training module, such as 582 `Scaffold`, or `CheckpointSaverHook`. 583 584 Returns: 585 `Saver`. 586 587 Raises: 588 RuntimeError: If the SAVERS collection already has more than one items. 589 """ 590 collection_key = ops.GraphKeys.SAVERS 591 savers = ops.get_collection(collection_key) 592 if savers: 593 if len(savers) > 1: 594 raise RuntimeError( 595 "More than one item in collection {}. " 596 "Please indicate which one to use by passing it to the constructor." 597 .format(collection_key)) 598 return savers[0] 599 saver = Saver(sharded=True, allow_empty=True) 600 if saver is not None: 601 ops.add_to_collection(collection_key, saver) 602 return saver 603 604 605@tf_export(v1=["train.Saver"]) 606class Saver(object): 607 """Saves and restores variables. 608 609 See [Variables](https://tensorflow.org/guide/variables) 610 for an overview of variables, saving and restoring. 611 612 The `Saver` class adds ops to save and restore variables to and from 613 *checkpoints*. It also provides convenience methods to run these ops. 614 615 Checkpoints are binary files in a proprietary format which map variable names 616 to tensor values. The best way to examine the contents of a checkpoint is to 617 load it using a `Saver`. 618 619 Savers can automatically number checkpoint filenames with a provided counter. 620 This lets you keep multiple checkpoints at different steps while training a 621 model. For example you can number the checkpoint filenames with the training 622 step number. To avoid filling up disks, savers manage checkpoint files 623 automatically. For example, they can keep only the N most recent files, or 624 one checkpoint for every N hours of training. 625 626 You number checkpoint filenames by passing a value to the optional 627 `global_step` argument to `save()`: 628 629 ```python 630 saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0' 631 ... 632 saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000' 633 ``` 634 635 Additionally, optional arguments to the `Saver()` constructor let you control 636 the proliferation of checkpoint files on disk: 637 638 * `max_to_keep` indicates the maximum number of recent checkpoint files to 639 keep. As new files are created, older files are deleted. If None or 0, 640 no checkpoints are deleted from the filesystem but only the last one is 641 kept in the `checkpoint` file. Defaults to 5 (that is, the 5 most recent 642 checkpoint files are kept.) 643 644 * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent 645 `max_to_keep` checkpoint files, you might want to keep one checkpoint file 646 for every N hours of training. This can be useful if you want to later 647 analyze how a model progressed during a long training session. For 648 example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep 649 one checkpoint file for every 2 hours of training. The default value of 650 10,000 hours effectively disables the feature. 651 652 Note that you still have to call the `save()` method to save the model. 653 Passing these arguments to the constructor will not save variables 654 automatically for you. 655 656 A training program that saves regularly looks like: 657 658 ```python 659 ... 660 # Create a saver. 661 saver = tf.compat.v1.train.Saver(...variables...) 662 # Launch the graph and train, saving the model every 1,000 steps. 663 sess = tf.compat.v1.Session() 664 for step in xrange(1000000): 665 sess.run(..training_op..) 666 if step % 1000 == 0: 667 # Append the step number to the checkpoint name: 668 saver.save(sess, 'my-model', global_step=step) 669 ``` 670 671 In addition to checkpoint files, savers keep a protocol buffer on disk with 672 the list of recent checkpoints. This is used to manage numbered checkpoint 673 files and by `latest_checkpoint()`, which makes it easy to discover the path 674 to the most recent checkpoint. That protocol buffer is stored in a file named 675 'checkpoint' next to the checkpoint files. 676 677 If you create several savers, you can specify a different filename for the 678 protocol buffer file in the call to `save()`. 679 """ 680 681 def __init__(self, 682 var_list=None, 683 reshape=False, 684 sharded=False, 685 max_to_keep=5, 686 keep_checkpoint_every_n_hours=10000.0, 687 name=None, 688 restore_sequentially=False, 689 saver_def=None, 690 builder=None, 691 defer_build=False, 692 allow_empty=False, 693 write_version=saver_pb2.SaverDef.V2, 694 pad_step_number=False, 695 save_relative_paths=False, 696 filename=None): 697 """Creates a `Saver`. 698 699 The constructor adds ops to save and restore variables. 700 701 `var_list` specifies the variables that will be saved and restored. It can 702 be passed as a `dict` or a list: 703 704 * A `dict` of names to variables: The keys are the names that will be 705 used to save or restore the variables in the checkpoint files. 706 * A list of variables: The variables will be keyed with their op name in 707 the checkpoint files. 708 709 For example: 710 711 ```python 712 v1 = tf.Variable(..., name='v1') 713 v2 = tf.Variable(..., name='v2') 714 715 # Pass the variables as a dict: 716 saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2}) 717 718 # Or pass them as a list. 719 saver = tf.compat.v1.train.Saver([v1, v2]) 720 # Passing a list is equivalent to passing a dict with the variable op names 721 # as keys: 722 saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]}) 723 ``` 724 725 Note: the newer `AutoTrackable` API is not supported by `Saver`. In this 726 case, the `tf.train.Checkpoint` class should be used. 727 728 The optional `reshape` argument, if `True`, allows restoring a variable from 729 a save file where the variable had a different shape, but the same number 730 of elements and type. This is useful if you have reshaped a variable and 731 want to reload it from an older checkpoint. 732 733 The optional `sharded` argument, if `True`, instructs the saver to shard 734 checkpoints per device. 735 736 Args: 737 var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping 738 names to `SaveableObject`s. If `None`, defaults to the list of all 739 saveable objects. 740 reshape: If `True`, allows restoring parameters from a checkpoint where 741 the variables have a different shape. 742 sharded: If `True`, shard the checkpoints, one per device. 743 max_to_keep: Maximum number of recent checkpoints to keep. Defaults to 5. 744 keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to 745 10,000 hours. 746 name: String. Optional name to use as a prefix when adding operations. 747 restore_sequentially: A `Bool`, which if true, causes restore of different 748 variables to happen sequentially within each device. This can lower 749 memory usage when restoring very large models. 750 saver_def: Optional `SaverDef` proto to use instead of running the 751 builder. This is only useful for specialty code that wants to recreate a 752 `Saver` object for a previously built `Graph` that had a `Saver`. The 753 `saver_def` proto should be the one returned by the `as_saver_def()` 754 call of the `Saver` that was created for that `Graph`. 755 builder: Optional `SaverBuilder` to use if a `saver_def` was not provided. 756 Defaults to `BulkSaverBuilder()`. 757 defer_build: If `True`, defer adding the save and restore ops to the 758 `build()` call. In that case `build()` should be called before 759 finalizing the graph or using the saver. 760 allow_empty: If `False` (default) raise an error if there are no variables 761 in the graph. Otherwise, construct the saver anyway and make it a no-op. 762 write_version: controls what format to use when saving checkpoints. It 763 also affects certain filepath matching logic. The V2 format is the 764 recommended choice: it is much more optimized than V1 in terms of memory 765 required and latency incurred during restore. Regardless of this 766 flag, the Saver is able to restore from both V2 and V1 checkpoints. 767 pad_step_number: if True, pads the global step number in the checkpoint 768 filepaths to some fixed width (8 by default). This is turned off by 769 default. 770 save_relative_paths: If `True`, will write relative paths to the 771 checkpoint state file. This is needed if the user wants to copy the 772 checkpoint directory and reload from the copied directory. 773 filename: If known at graph construction time, filename used for variable 774 loading/saving. 775 776 Raises: 777 TypeError: If `var_list` is invalid. 778 ValueError: If any of the keys or values in `var_list` are not unique. 779 RuntimeError: If eager execution is enabled and`var_list` does not specify 780 a list of variables to save. 781 782 @compatibility(eager) 783 When eager execution is enabled, `var_list` must specify a `list` or `dict` 784 of variables to save. Otherwise, a `RuntimeError` will be raised. 785 786 Although Saver works in some cases when executing eagerly, it is 787 fragile. Please switch to `tf.train.Checkpoint` or 788 `tf.keras.Model.save_weights`, which perform a more robust object-based 789 saving. These APIs will load checkpoints written by `Saver`. 790 @end_compatibility 791 """ 792 if defer_build and var_list: 793 raise ValueError( 794 "If `var_list` is provided then build cannot be deferred. " 795 "Either set defer_build=False or var_list=None.") 796 if context.executing_eagerly(): 797 logging.warning( 798 "Saver is deprecated, please switch to tf.train.Checkpoint or " 799 "tf.keras.Model.save_weights for training checkpoints. When " 800 "executing eagerly variables do not necessarily have unique names, " 801 "and so the variable.name-based lookups Saver performs are " 802 "error-prone.") 803 if var_list is None: 804 raise RuntimeError( 805 "When eager execution is enabled, `var_list` must specify a list " 806 "or dict of variables to save") 807 self._var_list = var_list 808 self._reshape = reshape 809 self._sharded = sharded 810 self._max_to_keep = max_to_keep 811 self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours 812 self._name = name 813 self._restore_sequentially = restore_sequentially 814 self.saver_def = saver_def 815 self._builder = builder 816 self._is_built = False 817 self._allow_empty = allow_empty 818 self._is_empty = None 819 self._write_version = write_version 820 self._pad_step_number = pad_step_number 821 self._filename = filename 822 self._last_checkpoints = [] 823 self._checkpoints_to_be_deleted = [] 824 if context.executing_eagerly(): 825 self._next_checkpoint_time = ( 826 time.time() + self._keep_checkpoint_every_n_hours * 3600) 827 elif not defer_build: 828 self.build() 829 if self.saver_def: 830 self._check_saver_def() 831 self._write_version = self.saver_def.version 832 self._save_relative_paths = save_relative_paths 833 # For compatibility with object-based checkpoints, we may build a second 834 # Saver to read the renamed keys. 835 self._object_restore_saver = None 836 837 def build(self): 838 if context.executing_eagerly(): 839 raise RuntimeError("Use save/restore instead of build in eager mode.") 840 self._build(self._filename, build_save=True, build_restore=True) 841 842 def _build_eager(self, checkpoint_path, build_save, build_restore): 843 self._build( 844 checkpoint_path, build_save=build_save, build_restore=build_restore) 845 846 def _build(self, checkpoint_path, build_save, build_restore): 847 """Builds saver_def.""" 848 if not context.executing_eagerly(): 849 if self._is_built: 850 return 851 self._is_built = True 852 853 if not self.saver_def or context.executing_eagerly(): 854 if self._builder is None: 855 self._builder = BulkSaverBuilder(self._write_version) 856 857 if self._var_list is None: 858 # pylint: disable=protected-access 859 self._var_list = variables._all_saveable_objects() 860 if not self._var_list: 861 if self._allow_empty: 862 self._is_empty = True 863 return 864 else: 865 raise ValueError("No variables to save") 866 self._is_empty = False 867 868 self.saver_def = self._builder._build_internal( # pylint: disable=protected-access 869 self._var_list, 870 reshape=self._reshape, 871 sharded=self._sharded, 872 max_to_keep=self._max_to_keep, 873 keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours, 874 name=self._name, 875 restore_sequentially=self._restore_sequentially, 876 filename=checkpoint_path, 877 build_save=build_save, 878 build_restore=build_restore) 879 elif self.saver_def and self._name: 880 # Since self._name is used as a name_scope by builder(), we are 881 # overloading the use of this field to represent the "import_scope" as 882 # well. 883 self.saver_def.filename_tensor_name = ops.prepend_name_scope( 884 self.saver_def.filename_tensor_name, self._name) 885 self.saver_def.save_tensor_name = ops.prepend_name_scope( 886 self.saver_def.save_tensor_name, self._name) 887 self.saver_def.restore_op_name = ops.prepend_name_scope( 888 self.saver_def.restore_op_name, self._name) 889 890 self._check_saver_def() 891 if not context.executing_eagerly(): 892 # Updates next checkpoint time. 893 # Set in __init__ when executing eagerly. 894 self._next_checkpoint_time = ( 895 time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600) 896 897 def _check_saver_def(self): 898 if not isinstance(self.saver_def, saver_pb2.SaverDef): 899 raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" % 900 self.saver_def) 901 if not context.executing_eagerly(): 902 if not self.saver_def.save_tensor_name: 903 raise ValueError("saver_def must specify the save_tensor_name: %s" % 904 str(self.saver_def)) 905 if not self.saver_def.restore_op_name: 906 raise ValueError("saver_def must specify the restore_op_name: %s" % 907 str(self.saver_def)) 908 909 def _CheckpointFilename(self, p): 910 """Returns the checkpoint filename given a `(filename, time)` pair. 911 912 Args: 913 p: (filename, time) pair. 914 915 Returns: 916 Checkpoint file name. 917 """ 918 name, _ = p 919 return name 920 921 def _RecordLastCheckpoint(self, latest_save_path): 922 """Manages the list of the latest checkpoints.""" 923 if not self.saver_def.max_to_keep: 924 return 925 # Remove first from list if the same name was used before. 926 for p in self._last_checkpoints: 927 if latest_save_path == self._CheckpointFilename(p): 928 self._last_checkpoints.remove(p) 929 # Append new path to list 930 self._last_checkpoints.append((latest_save_path, time.time())) 931 932 # If more than max_to_keep, remove oldest. 933 if len(self._last_checkpoints) > self.saver_def.max_to_keep: 934 self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0)) 935 936 def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"): 937 """Deletes old checkpoints if necessary. 938 939 `self._checkpoints_to_be_deleted` is going to contain checkpoints that are 940 over `max_to_keep`. They are going to be deleted. If 941 `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint 942 every `N` hours. For example, if `N` is 0.5, an additional checkpoint is 943 kept for every 0.5 hours of training; if `N` is 10, an additional 944 checkpoint is kept for every 10 hours of training. 945 946 Args: 947 meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. 948 """ 949 if self._checkpoints_to_be_deleted: 950 p = self._checkpoints_to_be_deleted.pop(0) 951 # Do not delete the file if we keep_checkpoint_every_n_hours is set and we 952 # have reached N hours of training. 953 should_keep = p[1] > self._next_checkpoint_time 954 if should_keep: 955 self._next_checkpoint_time += ( 956 self.saver_def.keep_checkpoint_every_n_hours * 3600) 957 return 958 959 # Otherwise delete the files. 960 try: 961 checkpoint_management.remove_checkpoint( 962 self._CheckpointFilename(p), self.saver_def.version, 963 meta_graph_suffix) 964 except Exception as e: # pylint: disable=broad-except 965 logging.warning("Ignoring: %s", str(e)) 966 967 def as_saver_def(self): 968 """Generates a `SaverDef` representation of this saver. 969 970 Returns: 971 A `SaverDef` proto. 972 """ 973 return self.saver_def 974 975 def to_proto(self, export_scope=None): 976 """Converts this `Saver` to a `SaverDef` protocol buffer. 977 978 Args: 979 export_scope: Optional `string`. Name scope to remove. 980 981 Returns: 982 A `SaverDef` protocol buffer. 983 """ 984 if export_scope is None: 985 return self.saver_def 986 987 if not (self.saver_def.filename_tensor_name.startswith(export_scope) and 988 self.saver_def.save_tensor_name.startswith(export_scope) and 989 self.saver_def.restore_op_name.startswith(export_scope)): 990 return None 991 992 saver_def = saver_pb2.SaverDef() 993 saver_def.CopyFrom(self.saver_def) 994 saver_def.filename_tensor_name = ops.strip_name_scope( 995 saver_def.filename_tensor_name, export_scope) 996 saver_def.save_tensor_name = ops.strip_name_scope( 997 saver_def.save_tensor_name, export_scope) 998 saver_def.restore_op_name = ops.strip_name_scope(saver_def.restore_op_name, 999 export_scope) 1000 return saver_def 1001 1002 @staticmethod 1003 def from_proto(saver_def, import_scope=None): 1004 """Returns a `Saver` object created from `saver_def`. 1005 1006 Args: 1007 saver_def: a `SaverDef` protocol buffer. 1008 import_scope: Optional `string`. Name scope to use. 1009 1010 Returns: 1011 A `Saver` built from saver_def. 1012 """ 1013 return Saver(saver_def=saver_def, name=import_scope) 1014 1015 @property 1016 def last_checkpoints(self): 1017 """List of not-yet-deleted checkpoint filenames. 1018 1019 You can pass any of the returned values to `restore()`. 1020 1021 Returns: 1022 A list of checkpoint filenames, sorted from oldest to newest. 1023 """ 1024 return list(self._CheckpointFilename(p) for p in self._last_checkpoints) 1025 1026 def set_last_checkpoints(self, last_checkpoints): 1027 """DEPRECATED: Use set_last_checkpoints_with_time. 1028 1029 Sets the list of old checkpoint filenames. 1030 1031 Args: 1032 last_checkpoints: A list of checkpoint filenames. 1033 1034 Raises: 1035 AssertionError: If last_checkpoints is not a list. 1036 """ 1037 assert isinstance(last_checkpoints, list) 1038 # We use a timestamp of +inf so that this checkpoint will never be 1039 # deleted. This is both safe and backwards compatible to a previous 1040 # version of the code which used s[1] as the "timestamp". 1041 self._last_checkpoints = [(s, np.inf) for s in last_checkpoints] 1042 1043 def set_last_checkpoints_with_time(self, last_checkpoints_with_time): 1044 """Sets the list of old checkpoint filenames and timestamps. 1045 1046 Args: 1047 last_checkpoints_with_time: A list of tuples of checkpoint filenames and 1048 timestamps. 1049 1050 Raises: 1051 AssertionError: If last_checkpoints_with_time is not a list. 1052 """ 1053 assert isinstance(last_checkpoints_with_time, list) 1054 self._last_checkpoints = last_checkpoints_with_time 1055 1056 def recover_last_checkpoints(self, checkpoint_paths): 1057 """Recovers the internal saver state after a crash. 1058 1059 This method is useful for recovering the "self._last_checkpoints" state. 1060 1061 Globs for the checkpoints pointed to by `checkpoint_paths`. If the files 1062 exist, use their mtime as the checkpoint timestamp. 1063 1064 Args: 1065 checkpoint_paths: a list of checkpoint paths. 1066 """ 1067 checkpoints_with_mtimes = [] 1068 for checkpoint_path in checkpoint_paths: 1069 mtime = checkpoint_management.get_checkpoint_mtimes([checkpoint_path]) 1070 if mtime: 1071 checkpoints_with_mtimes.append((checkpoint_path, mtime[0])) 1072 self.set_last_checkpoints_with_time(checkpoints_with_mtimes) 1073 1074 def save(self, 1075 sess, 1076 save_path, 1077 global_step=None, 1078 latest_filename=None, 1079 meta_graph_suffix="meta", 1080 write_meta_graph=True, 1081 write_state=True, 1082 strip_default_attrs=False, 1083 save_debug_info=False): 1084 # pylint: disable=line-too-long 1085 """Saves variables. 1086 1087 This method runs the ops added by the constructor for saving variables. 1088 It requires a session in which the graph was launched. The variables to 1089 save must also have been initialized. 1090 1091 The method returns the path prefix of the newly created checkpoint files. 1092 This string can be passed directly to a call to `restore()`. 1093 1094 Args: 1095 sess: A Session to use to save the variables. 1096 save_path: String. Prefix of filenames created for the checkpoint. 1097 global_step: If provided the global step number is appended to `save_path` 1098 to create the checkpoint filenames. The optional argument can be a 1099 `Tensor`, a `Tensor` name or an integer. 1100 latest_filename: Optional name for the protocol buffer file that will 1101 contains the list of most recent checkpoints. That file, kept in the 1102 same directory as the checkpoint files, is automatically managed by the 1103 saver to keep track of recent checkpoints. Defaults to 'checkpoint'. 1104 meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. 1105 write_meta_graph: `Boolean` indicating whether or not to write the meta 1106 graph file. 1107 write_state: `Boolean` indicating whether or not to write the 1108 `CheckpointStateProto`. 1109 strip_default_attrs: Boolean. If `True`, default-valued attributes will be 1110 removed from the NodeDefs. For a detailed guide, see 1111 [Stripping Default-Valued 1112 Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). 1113 save_debug_info: If `True`, save the GraphDebugInfo to a separate file, 1114 which in the same directory of save_path and with `_debug` added before 1115 the file extension. This is only enabled when `write_meta_graph` is 1116 `True` 1117 1118 Returns: 1119 A string: path prefix used for the checkpoint files. If the saver is 1120 sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn' 1121 is the number of shards created. 1122 If the saver is empty, returns None. 1123 1124 Raises: 1125 TypeError: If `sess` is not a `Session`. 1126 ValueError: If `latest_filename` contains path components, or if it 1127 collides with `save_path`. 1128 RuntimeError: If save and restore ops weren't built. 1129 """ 1130 # pylint: enable=line-too-long 1131 if not self._is_built and not context.executing_eagerly(): 1132 raise RuntimeError( 1133 "`build()` should be called before save if defer_build==True") 1134 if latest_filename is None: 1135 latest_filename = "checkpoint" 1136 if self._write_version != saver_pb2.SaverDef.V2: 1137 logging.warning("*******************************************************") 1138 logging.warning("TensorFlow's V1 checkpoint format has been deprecated.") 1139 logging.warning("Consider switching to the more efficient V2 format:") 1140 logging.warning(" `tf.train.Saver(write_version=tf.train.SaverDef.V2)`") 1141 logging.warning("now on by default.") 1142 logging.warning("*******************************************************") 1143 1144 if os.path.split(latest_filename)[0]: 1145 raise ValueError("'latest_filename' must not contain path components") 1146 1147 if global_step is not None: 1148 if not isinstance(global_step, compat.integral_types): 1149 global_step = training_util.global_step(sess, global_step) 1150 checkpoint_file = "%s-%d" % (save_path, global_step) 1151 if self._pad_step_number: 1152 # Zero-pads the step numbers, so that they are sorted when listed. 1153 checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step)) 1154 else: 1155 checkpoint_file = save_path 1156 if os.path.basename(save_path) == latest_filename and not self._sharded: 1157 # Guard against collision between data file and checkpoint state file. 1158 raise ValueError( 1159 "'latest_filename' collides with 'save_path': '%s' and '%s'" % 1160 (latest_filename, save_path)) 1161 1162 if (not context.executing_eagerly() and 1163 not isinstance(sess, session.SessionInterface)): 1164 raise TypeError("'sess' must be a Session; %s" % sess) 1165 1166 save_path_parent = os.path.dirname(save_path) 1167 if not self._is_empty: 1168 try: 1169 if context.executing_eagerly(): 1170 self._build_eager( 1171 checkpoint_file, build_save=True, build_restore=False) 1172 model_checkpoint_path = self.saver_def.save_tensor_name 1173 else: 1174 model_checkpoint_path = sess.run( 1175 self.saver_def.save_tensor_name, 1176 {self.saver_def.filename_tensor_name: checkpoint_file}) 1177 1178 model_checkpoint_path = compat.as_str(model_checkpoint_path) 1179 if write_state: 1180 self._RecordLastCheckpoint(model_checkpoint_path) 1181 checkpoint_management.update_checkpoint_state_internal( 1182 save_dir=save_path_parent, 1183 model_checkpoint_path=model_checkpoint_path, 1184 all_model_checkpoint_paths=self.last_checkpoints, 1185 latest_filename=latest_filename, 1186 save_relative_paths=self._save_relative_paths) 1187 self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix) 1188 except (errors.FailedPreconditionError, errors.NotFoundError) as exc: 1189 if not gfile.IsDirectory(save_path_parent): 1190 exc = ValueError( 1191 "Parent directory of {} doesn't exist, can't save.".format( 1192 save_path)) 1193 raise exc 1194 1195 if write_meta_graph: 1196 meta_graph_filename = checkpoint_management.meta_graph_filename( 1197 checkpoint_file, meta_graph_suffix=meta_graph_suffix) 1198 if not context.executing_eagerly(): 1199 with sess.graph.as_default(): 1200 self.export_meta_graph( 1201 meta_graph_filename, 1202 strip_default_attrs=strip_default_attrs, 1203 save_debug_info=save_debug_info) 1204 1205 if self._is_empty: 1206 return None 1207 else: 1208 return model_checkpoint_path 1209 1210 def export_meta_graph(self, 1211 filename=None, 1212 collection_list=None, 1213 as_text=False, 1214 export_scope=None, 1215 clear_devices=False, 1216 clear_extraneous_savers=False, 1217 strip_default_attrs=False, 1218 save_debug_info=False): 1219 # pylint: disable=line-too-long 1220 """Writes `MetaGraphDef` to save_path/filename. 1221 1222 Args: 1223 filename: Optional meta_graph filename including the path. 1224 collection_list: List of string keys to collect. 1225 as_text: If `True`, writes the meta_graph as an ASCII proto. 1226 export_scope: Optional `string`. Name scope to remove. 1227 clear_devices: Whether or not to clear the device field for an `Operation` 1228 or `Tensor` during export. 1229 clear_extraneous_savers: Remove any Saver-related information from the 1230 graph (both Save/Restore ops and SaverDefs) that are not associated with 1231 this Saver. 1232 strip_default_attrs: Boolean. If `True`, default-valued attributes will be 1233 removed from the NodeDefs. For a detailed guide, see 1234 [Stripping Default-Valued 1235 Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). 1236 save_debug_info: If `True`, save the GraphDebugInfo to a separate file, 1237 which in the same directory of filename and with `_debug` added before 1238 the file extension. 1239 1240 Returns: 1241 A `MetaGraphDef` proto. 1242 """ 1243 # pylint: enable=line-too-long 1244 return export_meta_graph( 1245 filename=filename, 1246 graph_def=ops.get_default_graph().as_graph_def(add_shapes=True), 1247 saver_def=self.saver_def, 1248 collection_list=collection_list, 1249 as_text=as_text, 1250 export_scope=export_scope, 1251 clear_devices=clear_devices, 1252 clear_extraneous_savers=clear_extraneous_savers, 1253 strip_default_attrs=strip_default_attrs, 1254 save_debug_info=save_debug_info) 1255 1256 def restore(self, sess, save_path): 1257 """Restores previously saved variables. 1258 1259 This method runs the ops added by the constructor for restoring variables. 1260 It requires a session in which the graph was launched. The variables to 1261 restore do not have to have been initialized, as restoring is itself a way 1262 to initialize variables. 1263 1264 The `save_path` argument is typically a value previously returned from a 1265 `save()` call, or a call to `latest_checkpoint()`. 1266 1267 Args: 1268 sess: A `Session` to use to restore the parameters. None in eager mode. 1269 save_path: Path where parameters were previously saved. 1270 1271 Raises: 1272 ValueError: If save_path is None or not a valid checkpoint. 1273 """ 1274 if self._is_empty: 1275 return 1276 if save_path is None: 1277 raise ValueError("Can't load save_path when it is None.") 1278 1279 checkpoint_prefix = compat.as_text(save_path) 1280 if not checkpoint_management.checkpoint_exists_internal(checkpoint_prefix): 1281 raise ValueError("The passed save_path is not a valid checkpoint: " + 1282 checkpoint_prefix) 1283 1284 logging.info("Restoring parameters from %s", checkpoint_prefix) 1285 try: 1286 if context.executing_eagerly(): 1287 self._build_eager(save_path, build_save=False, build_restore=True) 1288 else: 1289 sess.run(self.saver_def.restore_op_name, 1290 {self.saver_def.filename_tensor_name: save_path}) 1291 except errors.NotFoundError as err: 1292 # There are three common conditions that might cause this error: 1293 # 0. The file is missing. We ignore here, as this is checked above. 1294 # 1. This is an object-based checkpoint trying name-based loading. 1295 # 2. The graph has been altered and a variable or other name is missing. 1296 1297 # 1. The checkpoint would not be loaded successfully as is. Try to parse 1298 # it as an object-based checkpoint. 1299 try: 1300 names_to_keys = object_graph_key_mapping(save_path) 1301 except errors.NotFoundError: 1302 # 2. This is not an object-based checkpoint, which likely means there 1303 # is a graph mismatch. Re-raise the original error with 1304 # a helpful message (b/110263146) 1305 raise _wrap_restore_error_with_msg( 1306 err, "a Variable name or other graph key that is missing") 1307 1308 # This is an object-based checkpoint. We'll print a warning and then do 1309 # the restore. 1310 logging.warning( 1311 "Restoring an object-based checkpoint using a name-based saver. This " 1312 "may be somewhat fragile, and will re-build the Saver. Instead, " 1313 "consider loading object-based checkpoints using " 1314 "tf.train.Checkpoint().") 1315 self._object_restore_saver = saver_from_object_based_checkpoint( 1316 checkpoint_path=save_path, 1317 var_list=self._var_list, 1318 builder=self._builder, 1319 names_to_keys=names_to_keys, 1320 cached_saver=self._object_restore_saver) 1321 self._object_restore_saver.restore(sess=sess, save_path=save_path) 1322 except errors.InvalidArgumentError as err: 1323 # There is a mismatch between the graph and the checkpoint being loaded. 1324 # We add a more reasonable error message here to help users (b/110263146) 1325 raise _wrap_restore_error_with_msg( 1326 err, "a mismatch between the current graph and the graph") 1327 1328 @staticmethod 1329 def _add_collection_def(meta_graph_def, key, export_scope=None): 1330 """Adds a collection to MetaGraphDef protocol buffer. 1331 1332 Args: 1333 meta_graph_def: MetaGraphDef protocol buffer. 1334 key: One of the GraphKeys or user-defined string. 1335 export_scope: Optional `string`. Name scope to remove. 1336 """ 1337 meta_graph.add_collection_def( 1338 meta_graph_def, key, export_scope=export_scope) 1339 1340 1341@tf_export(v1=["train.import_meta_graph"]) 1342def import_meta_graph(meta_graph_or_file, 1343 clear_devices=False, 1344 import_scope=None, 1345 **kwargs): 1346 """Recreates a Graph saved in a `MetaGraphDef` proto. 1347 1348 This function takes a `MetaGraphDef` protocol buffer as input. If 1349 the argument is a file containing a `MetaGraphDef` protocol buffer , 1350 it constructs a protocol buffer from the file content. The function 1351 then adds all the nodes from the `graph_def` field to the 1352 current graph, recreates all the collections, and returns a saver 1353 constructed from the `saver_def` field. 1354 1355 In combination with `export_meta_graph()`, this function can be used to 1356 1357 * Serialize a graph along with other Python objects such as `QueueRunner`, 1358 `Variable` into a `MetaGraphDef`. 1359 1360 * Restart training from a saved graph and checkpoints. 1361 1362 * Run inference from a saved graph and checkpoints. 1363 1364 ```Python 1365 ... 1366 # Create a saver. 1367 saver = tf.compat.v1.train.Saver(...variables...) 1368 # Remember the training_op we want to run by adding it to a collection. 1369 tf.compat.v1.add_to_collection('train_op', train_op) 1370 sess = tf.compat.v1.Session() 1371 for step in xrange(1000000): 1372 sess.run(train_op) 1373 if step % 1000 == 0: 1374 # Saves checkpoint, which by default also exports a meta_graph 1375 # named 'my-model-global_step.meta'. 1376 saver.save(sess, 'my-model', global_step=step) 1377 ``` 1378 1379 Later we can continue training from this saved `meta_graph` without building 1380 the model from scratch. 1381 1382 ```Python 1383 with tf.Session() as sess: 1384 new_saver = 1385 tf.train.import_meta_graph('my-save-dir/my-model-10000.meta') 1386 new_saver.restore(sess, 'my-save-dir/my-model-10000') 1387 # tf.get_collection() returns a list. In this example we only want 1388 # the first one. 1389 train_op = tf.get_collection('train_op')[0] 1390 for step in xrange(1000000): 1391 sess.run(train_op) 1392 ``` 1393 1394 NOTE: Restarting training from saved `meta_graph` only works if the 1395 device assignments have not changed. 1396 1397 Example: 1398 Variables, placeholders, and independent operations can also be stored, as 1399 shown in the following example. 1400 1401 ```Python 1402 # Saving contents and operations. 1403 v1 = tf.placeholder(tf.float32, name="v1") 1404 v2 = tf.placeholder(tf.float32, name="v2") 1405 v3 = tf.math.multiply(v1, v2) 1406 vx = tf.Variable(10.0, name="vx") 1407 v4 = tf.add(v3, vx, name="v4") 1408 saver = tf.train.Saver([vx]) 1409 sess = tf.Session() 1410 sess.run(tf.global_variables_initializer()) 1411 sess.run(vx.assign(tf.add(vx, vx))) 1412 result = sess.run(v4, feed_dict={v1:12.0, v2:3.3}) 1413 print(result) 1414 saver.save(sess, "./model_ex1") 1415 ``` 1416 1417 Later this model can be restored and contents loaded. 1418 1419 ```Python 1420 # Restoring variables and running operations. 1421 saver = tf.train.import_meta_graph("./model_ex1.meta") 1422 sess = tf.Session() 1423 saver.restore(sess, "./model_ex1") 1424 result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3}) 1425 print(result) 1426 ``` 1427 1428 Args: 1429 meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including 1430 the path) containing a `MetaGraphDef`. 1431 clear_devices: Whether or not to clear the device field for an `Operation` 1432 or `Tensor` during import. 1433 import_scope: Optional `string`. Name scope to add. Only used when 1434 initializing from protocol buffer. 1435 **kwargs: Optional keyed arguments. 1436 1437 Returns: 1438 A saver constructed from `saver_def` in `MetaGraphDef` or None. 1439 1440 A None value is returned if no variables exist in the `MetaGraphDef` 1441 (i.e., there are no variables to restore). 1442 1443 Raises: 1444 RuntimeError: If called with eager execution enabled. 1445 1446 @compatibility(eager) 1447 Exporting/importing meta graphs is not supported. No graph exists when eager 1448 execution is enabled. 1449 @end_compatibility 1450 """ # pylint: disable=g-doc-exception 1451 return _import_meta_graph_with_return_elements(meta_graph_or_file, 1452 clear_devices, import_scope, 1453 **kwargs)[0] 1454 1455 1456def _import_meta_graph_with_return_elements(meta_graph_or_file, 1457 clear_devices=False, 1458 import_scope=None, 1459 return_elements=None, 1460 **kwargs): 1461 """Import MetaGraph, and return both a saver and returned elements.""" 1462 if context.executing_eagerly(): 1463 raise RuntimeError("Exporting/importing meta graphs is not supported when " 1464 "eager execution is enabled. No graph exists when eager " 1465 "execution is enabled.") 1466 if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef): 1467 meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file) 1468 else: 1469 meta_graph_def = meta_graph_or_file 1470 1471 imported_vars, imported_return_elements = ( 1472 meta_graph.import_scoped_meta_graph_with_return_elements( 1473 meta_graph_def, 1474 clear_devices=clear_devices, 1475 import_scope=import_scope, 1476 return_elements=return_elements, 1477 **kwargs)) 1478 1479 saver = _create_saver_from_imported_meta_graph(meta_graph_def, import_scope, 1480 imported_vars) 1481 return saver, imported_return_elements 1482 1483 1484def _create_saver_from_imported_meta_graph(meta_graph_def, import_scope, 1485 imported_vars): 1486 """Return a saver for restoring variable values to an imported MetaGraph.""" 1487 if meta_graph_def.HasField("saver_def"): 1488 # Infer the scope that is prepended by `import_scoped_meta_graph`. 1489 scope = import_scope 1490 var_names = list(imported_vars.keys()) 1491 if var_names: 1492 sample_key = var_names[0] 1493 sample_var = imported_vars[sample_key] 1494 scope = sample_var.name[:-len(sample_key)] 1495 1496 return Saver(saver_def=meta_graph_def.saver_def, name=scope) 1497 else: 1498 if variables._all_saveable_objects(scope=import_scope): # pylint: disable=protected-access 1499 # Return the default saver instance for all graph variables. 1500 return Saver() 1501 else: 1502 # If no graph variables exist, then a Saver cannot be constructed. 1503 logging.info("Saver not created because there are no variables in the" 1504 " graph to restore") 1505 return None 1506 1507 1508@tf_export(v1=["train.export_meta_graph"]) 1509def export_meta_graph(filename=None, 1510 meta_info_def=None, 1511 graph_def=None, 1512 saver_def=None, 1513 collection_list=None, 1514 as_text=False, 1515 graph=None, 1516 export_scope=None, 1517 clear_devices=False, 1518 clear_extraneous_savers=False, 1519 strip_default_attrs=False, 1520 save_debug_info=False, 1521 **kwargs): 1522 # pylint: disable=line-too-long 1523 """Returns `MetaGraphDef` proto. 1524 1525 Optionally writes it to filename. 1526 1527 This function exports the graph, saver, and collection objects into 1528 `MetaGraphDef` protocol buffer with the intention of it being imported 1529 at a later time or location to restart training, run inference, or be 1530 a subgraph. 1531 1532 Args: 1533 filename: Optional filename including the path for writing the generated 1534 `MetaGraphDef` protocol buffer. 1535 meta_info_def: `MetaInfoDef` protocol buffer. 1536 graph_def: `GraphDef` protocol buffer. 1537 saver_def: `SaverDef` protocol buffer. 1538 collection_list: List of string keys to collect. 1539 as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto. 1540 graph: The `Graph` to export. If `None`, use the default graph. 1541 export_scope: Optional `string`. Name scope under which to extract the 1542 subgraph. The scope name will be striped from the node definitions for 1543 easy import later into new name scopes. If `None`, the whole graph is 1544 exported. graph_def and export_scope cannot both be specified. 1545 clear_devices: Whether or not to clear the device field for an `Operation` 1546 or `Tensor` during export. 1547 clear_extraneous_savers: Remove any Saver-related information from the graph 1548 (both Save/Restore ops and SaverDefs) that are not associated with the 1549 provided SaverDef. 1550 strip_default_attrs: Boolean. If `True`, default-valued attributes will be 1551 removed from the NodeDefs. For a detailed guide, see 1552 [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). 1553 save_debug_info: If `True`, save the GraphDebugInfo to a separate file, 1554 which in the same directory of filename and with `_debug` added before the 1555 file extend. 1556 **kwargs: Optional keyed arguments. 1557 1558 Returns: 1559 A `MetaGraphDef` proto. 1560 1561 Raises: 1562 ValueError: When the `GraphDef` is larger than 2GB. 1563 RuntimeError: If called with eager execution enabled. 1564 1565 @compatibility(eager) 1566 Exporting/importing meta graphs is not supported unless both `graph_def` and 1567 `graph` are provided. No graph exists when eager execution is enabled. 1568 @end_compatibility 1569 """ 1570 # pylint: enable=line-too-long 1571 if context.executing_eagerly() and not (graph_def is not None and 1572 graph is not None): 1573 raise RuntimeError("Exporting/importing meta graphs is not supported when " 1574 "eager execution is enabled. No graph exists when eager " 1575 "execution is enabled.") 1576 meta_graph_def, _ = meta_graph.export_scoped_meta_graph( 1577 filename=filename, 1578 meta_info_def=meta_info_def, 1579 graph_def=graph_def, 1580 saver_def=saver_def, 1581 collection_list=collection_list, 1582 as_text=as_text, 1583 graph=graph, 1584 export_scope=export_scope, 1585 clear_devices=clear_devices, 1586 clear_extraneous_savers=clear_extraneous_savers, 1587 strip_default_attrs=strip_default_attrs, 1588 save_debug_info=save_debug_info, 1589 **kwargs) 1590 return meta_graph_def 1591 1592 1593def _wrap_restore_error_with_msg(err, extra_verbiage): 1594 err_msg = ("Restoring from checkpoint failed. This is most likely " 1595 "due to {} from the checkpoint. Please ensure that you " 1596 "have not altered the graph expected based on the checkpoint. " 1597 "Original error:\n\n{}").format(extra_verbiage, err.message) 1598 return err.__class__(err.node_def, err.op, err_msg) 1599 1600 1601ops.register_proto_function( 1602 ops.GraphKeys.SAVERS, 1603 proto_type=saver_pb2.SaverDef, 1604 to_proto=Saver.to_proto, 1605 from_proto=Saver.from_proto) 1606 1607 1608def object_graph_key_mapping(checkpoint_path): 1609 """Return name to key mappings from the checkpoint. 1610 1611 Args: 1612 checkpoint_path: string, path to object-based checkpoint 1613 1614 Returns: 1615 Dictionary mapping tensor names to checkpoint keys. 1616 """ 1617 reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path) 1618 object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY) 1619 object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph()) 1620 object_graph_proto.ParseFromString(object_graph_string) 1621 names_to_keys = {} 1622 for node in object_graph_proto.nodes: 1623 for attribute in node.attributes: 1624 names_to_keys[attribute.full_name] = attribute.checkpoint_key 1625 return names_to_keys 1626 1627 1628def saver_from_object_based_checkpoint(checkpoint_path, 1629 var_list=None, 1630 builder=None, 1631 names_to_keys=None, 1632 cached_saver=None): 1633 """Return a `Saver` which reads from an object-based checkpoint. 1634 1635 This function validates that all variables in the variables list are remapped 1636 in the object-based checkpoint (or `names_to_keys` dict if provided). A 1637 saver will be created with the list of remapped variables. 1638 1639 The `cached_saver` argument allows the user to pass in a previously created 1640 saver, so multiple `saver.restore()` calls don't pollute the graph when graph 1641 building. This assumes that keys are consistent, meaning that the 1642 1) `checkpoint_path` checkpoint, and 1643 2) checkpoint used to create the `cached_saver` 1644 are the same type of object-based checkpoint. If this argument is set, this 1645 function will simply validate that all variables have been remapped by the 1646 checkpoint at `checkpoint_path`. 1647 1648 Note that in general, `tf.train.Checkpoint` should be used to restore/save an 1649 object-based checkpoint. 1650 1651 Args: 1652 checkpoint_path: string, path to object-based checkpoint 1653 var_list: list of `Variables` that appear in the checkpoint. If `None`, 1654 `var_list` will be set to all saveable objects. 1655 builder: a `BaseSaverBuilder` instance. If `None`, a new `BulkSaverBuilder` 1656 will be created. 1657 names_to_keys: dict mapping string tensor names to checkpooint keys. If 1658 `None`, this dict will be generated from the checkpoint file. 1659 cached_saver: Cached `Saver` object with remapped variables. 1660 1661 Returns: 1662 `Saver` with remapped variables for reading from an object-based checkpoint. 1663 1664 Raises: 1665 ValueError if the checkpoint provided is not an object-based checkpoint. 1666 NotFoundError: If one of the variables in `var_list` can not be found in the 1667 checkpoint. This could mean the checkpoint or `names_to_keys` mapping is 1668 missing the variable. 1669 """ 1670 if names_to_keys is None: 1671 try: 1672 names_to_keys = object_graph_key_mapping(checkpoint_path) 1673 except errors.NotFoundError: 1674 raise ValueError("Checkpoint in %s not an object-based checkpoint." % 1675 checkpoint_path) 1676 if var_list is None: 1677 var_list = variables._all_saveable_objects() # pylint: disable=protected-access 1678 if builder is None: 1679 builder = BulkSaverBuilder() 1680 1681 saveables = saveable_object_util.validate_and_slice_inputs(var_list) 1682 current_names = set() 1683 for saveable in saveables: 1684 for spec in saveable.specs: 1685 current_names.add(spec.name) 1686 previous_names = set(names_to_keys.keys()) 1687 missing_names = current_names - previous_names 1688 if missing_names: 1689 extra_names = previous_names - current_names 1690 intersecting_names = previous_names.intersection(current_names) 1691 raise errors.NotFoundError( 1692 None, 1693 None, 1694 message=( 1695 "\n\nExisting variables not in the checkpoint: %s\n\n" 1696 "Variables names when this checkpoint was written which don't " 1697 "exist now: %s\n\n" 1698 "(%d variable name(s) did match)\n\n" 1699 "Could not find some variables in the checkpoint (see names " 1700 "above). Saver was attempting to load an object-based checkpoint " 1701 "(saved using tf.train.Checkpoint or tf.keras.Model.save_weights) " 1702 "using variable names. If the checkpoint was written with eager " 1703 "execution enabled, it's possible that variable names have " 1704 "changed (for example missing a '_1' suffix). It's also " 1705 "possible that there are new variables which did not exist " 1706 "when the checkpoint was written. You can construct a " 1707 "Saver(var_list=...) with only the variables which previously " 1708 "existed, and if variable names have changed you may need to " 1709 "make this a dictionary with the old names as keys. If you're " 1710 "using an Estimator, you'll need to return a tf.train.Saver " 1711 "inside a tf.train.Scaffold from your model_fn.") % 1712 (", ".join(sorted(missing_names)), ", ".join( 1713 sorted(extra_names)), len(intersecting_names))) 1714 for saveable in saveables: 1715 for spec in saveable.specs: 1716 spec.name = names_to_keys[spec.name] 1717 if cached_saver is None: 1718 return Saver(saveables) 1719 return cached_saver 1720