1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16# pylint: disable=invalid-name 17"""Save and restore variables. 18 19Symbols in this file are deprecated. See replacements in 20tensorflow/python/training/trackable and tensorflow/python/training/saving. 21""" 22from __future__ import absolute_import 23from __future__ import division 24from __future__ import print_function 25 26import collections 27import os.path 28import time 29import uuid 30 31import numpy as np 32from tensorflow.core.protobuf import meta_graph_pb2 33from tensorflow.core.protobuf import saver_pb2 34from tensorflow.core.protobuf import trackable_object_graph_pb2 35from tensorflow.python import pywrap_tensorflow 36from tensorflow.python.client import session 37from tensorflow.python.eager import context 38from tensorflow.python.framework import constant_op 39from tensorflow.python.framework import device as pydev 40from tensorflow.python.framework import errors 41from tensorflow.python.framework import meta_graph 42from tensorflow.python.framework import ops 43from tensorflow.python.ops import array_ops 44from tensorflow.python.ops import control_flow_ops 45from tensorflow.python.ops import gen_io_ops 46from tensorflow.python.ops import io_ops 47from tensorflow.python.ops import string_ops 48from tensorflow.python.ops import variables 49from tensorflow.python.platform import gfile 50from tensorflow.python.platform import tf_logging as logging 51from tensorflow.python.training import checkpoint_management 52from tensorflow.python.training import training_util 53from tensorflow.python.training.saving import saveable_object 54from tensorflow.python.training.saving import saveable_object_util 55from tensorflow.python.training.tracking import base as trackable 56from tensorflow.python.util import compat 57from tensorflow.python.util.tf_export import tf_export 58 59 60# TODO(allenl): Remove these aliases once all users are migrated off. 61get_checkpoint_state = checkpoint_management.get_checkpoint_state 62update_checkpoint_state = checkpoint_management.update_checkpoint_state 63generate_checkpoint_state_proto = ( 64 checkpoint_management.generate_checkpoint_state_proto) 65latest_checkpoint = checkpoint_management.latest_checkpoint 66checkpoint_exists = checkpoint_management.checkpoint_exists 67get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes 68remove_checkpoint = checkpoint_management.remove_checkpoint 69 70 71class BaseSaverBuilder(object): 72 """Base class for Savers. 73 74 Can be extended to create different Ops. 75 """ 76 77 SaveSpec = saveable_object.SaveSpec 78 SaveableObject = saveable_object.SaveableObject 79 80 # Aliases for code which was moved but still has lots of users. 
  VariableSaveable = saveable_object_util.ReferenceVariableSaveable
  ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable

  def __init__(self, write_version=saver_pb2.SaverDef.V2):
    self._write_version = write_version

  def save_op(self, filename_tensor, saveables):
    """Create an Op to save 'saveables'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor.
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      An Operation that saves the variables.

    Raises:
      RuntimeError: (implementation detail) if "self._write_version" is an
        unexpected value.
    """
    # pylint: disable=protected-access
    tensor_names = []
    tensors = []
    tensor_slices = []
    for saveable in saveables:
      for spec in saveable.specs:
        tensor_names.append(spec.name)
        tensors.append(spec.tensor)
        tensor_slices.append(spec.slice_spec)
    if self._write_version == saver_pb2.SaverDef.V1:
      return io_ops._save(
          filename=filename_tensor,
          tensor_names=tensor_names,
          tensors=tensors,
          tensor_slices=tensor_slices)
    elif self._write_version == saver_pb2.SaverDef.V2:
      # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
      # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
      return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                            tensors)
    else:
      raise RuntimeError("Unexpected write_version: " + self._write_version)

  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    """Restore all tensors contained in saveables.

    By default, this issues separate calls to `restore_op` for each saveable.
    Subclasses may override to load multiple saveables in a single call.

    Args:
      filename_tensor: String Tensor.
      saveables: List of BaseSaverBuilder.SaveableObject objects.
      preferred_shard: Int. Shard to open first when loading a sharded file.
      restore_sequentially: Unused. Bool. If true, each restore is sequential.

    Returns:
      A list of Tensors resulting from reading 'saveable' from
        'filename'.
    """
    del restore_sequentially
    all_tensors = []
    for saveable in saveables:
      if saveable.device:
        device = saveable_object_util.set_cpu0(saveable.device)
      else:
        device = None
      with ops.device(device):
        all_tensors.extend(
            self.restore_op(filename_tensor, saveable, preferred_shard))
    return all_tensors

  # pylint: disable=unused-argument
  def restore_op(self, filename_tensor, saveable, preferred_shard):
    """Create ops to restore 'saveable'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor.
      saveable: A BaseSaverBuilder.SaveableObject object.
      preferred_shard: Int. Shard to open first when loading a sharded file.

    Returns:
      A list of Tensors resulting from reading 'saveable' from
        'filename'.
    """
    # pylint: disable=protected-access
    tensors = []
    for spec in saveable.specs:
      tensors.append(
          io_ops.restore_v2(
              filename_tensor,
              [spec.name],
              [spec.slice_spec],
              [spec.dtype])[0])

    return tensors
  # pylint: enable=unused-argument

  def sharded_filename(self, filename_tensor, shard, num_shards):
    """Append sharding information to a filename.
188 189 Args: 190 filename_tensor: A string tensor. 191 shard: Integer. The shard for the filename. 192 num_shards: An int Tensor for the number of shards. 193 194 Returns: 195 A string tensor. 196 """ 197 return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards) 198 199 def _AddSaveOps(self, filename_tensor, saveables): 200 """Add ops to save variables that are on the same shard. 201 202 Args: 203 filename_tensor: String Tensor. 204 saveables: A list of SaveableObject objects. 205 206 Returns: 207 A tensor with the filename used to save. 208 """ 209 save = self.save_op(filename_tensor, saveables) 210 return control_flow_ops.with_dependencies([save], filename_tensor) 211 212 def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device): 213 """Add ops to save the params per shard, for the V2 format. 214 215 Note that the sharded save procedure for the V2 format is different from 216 V1: there is a special "merge" step that merges the small metadata produced 217 from each device. 218 219 Args: 220 checkpoint_prefix: scalar String Tensor. Interpreted *NOT AS A 221 FILENAME*, but as a prefix of a V2 checkpoint; 222 per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as 223 returned by _GroupByDevices(). 224 225 Returns: 226 An op to save the variables, which, when evaluated, returns the prefix 227 "<user-fed prefix>" only and does not include the sharded spec suffix. 228 """ 229 # IMPLEMENTATION DETAILS: most clients should skip. 230 # 231 # Suffix for any well-formed "checkpoint_prefix", when sharded. 232 # Transformations: 233 # * Users pass in "save_path" in save() and restore(). Say "myckpt". 234 # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>. 235 # 236 # Example: 237 # During runtime, a temporary directory is first created, which contains 238 # files 239 # 240 # <train dir>/myckpt_temp/ 241 # part-?????-of-?????{.index, .data-00000-of-00001} 242 # 243 # Before .save() finishes, they will be (hopefully, atomically) renamed to 244 # 245 # <train dir>/ 246 # myckpt{.index, .data-?????-of-?????} 247 # 248 # Users only need to interact with the user-specified prefix, which is 249 # "<train dir>/myckpt" in this case. Save() and Restore() work with the 250 # prefix directly, instead of any physical pathname. (On failure and 251 # subsequent restore, an outdated and orphaned temporary directory can be 252 # safely removed.) 253 _SHARDED_SUFFIX = "_temp_%s/part" % uuid.uuid4().hex 254 tmp_checkpoint_prefix = string_ops.string_join( 255 [checkpoint_prefix, _SHARDED_SUFFIX]) 256 257 num_shards = len(per_device) 258 sharded_saves = [] 259 sharded_prefixes = [] 260 num_shards_tensor = constant_op.constant(num_shards, name="num_shards") 261 last_device = None 262 for shard, (device, saveables) in enumerate(per_device): 263 last_device = device 264 with ops.device(saveable_object_util.set_cpu0(device)): 265 sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard, 266 num_shards_tensor) 267 sharded_prefixes.append(sharded_filename) 268 sharded_saves.append(self._AddSaveOps(sharded_filename, saveables)) 269 270 with ops.control_dependencies([x.op for x in sharded_saves]): 271 # Co-locates the merge step with the last device. 272 with ops.device(saveable_object_util.set_cpu0(last_device)): 273 # V2 format write path consists of a metadata merge step. Once merged, 274 # attempts to delete the temporary directory, "<user-fed prefix>_temp". 
275 merge_step = gen_io_ops.merge_v2_checkpoints( 276 sharded_prefixes, checkpoint_prefix, delete_old_dirs=True) 277 with ops.control_dependencies([merge_step]): 278 # Returns the prefix "<user-fed prefix>" only. DOES NOT include the 279 # sharded spec suffix. 280 return array_ops.identity(checkpoint_prefix) 281 282 def _AddShardedSaveOps(self, filename_tensor, per_device): 283 """Add ops to save the params per shard. 284 285 Args: 286 filename_tensor: a scalar String Tensor. 287 per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as 288 returned by _GroupByDevices(). 289 290 Returns: 291 An op to save the variables. 292 """ 293 if self._write_version == saver_pb2.SaverDef.V2: 294 return self._AddShardedSaveOpsForV2(filename_tensor, per_device) 295 296 num_shards = len(per_device) 297 sharded_saves = [] 298 num_shards_tensor = constant_op.constant(num_shards, name="num_shards") 299 for shard, (device, saveables) in enumerate(per_device): 300 with ops.device(device): 301 sharded_filename = self.sharded_filename(filename_tensor, shard, 302 num_shards_tensor) 303 sharded_saves.append(self._AddSaveOps(sharded_filename, saveables)) 304 # Return the sharded name for the save path. 305 with ops.control_dependencies([x.op for x in sharded_saves]): 306 return gen_io_ops.sharded_filespec(filename_tensor, num_shards_tensor) 307 308 def _AddRestoreOps(self, 309 filename_tensor, 310 saveables, 311 restore_sequentially, 312 reshape, 313 preferred_shard=-1, 314 name="restore_all"): 315 """Add operations to restore saveables. 316 317 Args: 318 filename_tensor: Tensor for the path of the file to load. 319 saveables: A list of SaveableObject objects. 320 restore_sequentially: True if we want to restore variables sequentially 321 within a shard. 322 reshape: True if we want to reshape loaded tensors to the shape of 323 the corresponding variable. 324 preferred_shard: Shard to open first when loading a sharded file. 325 name: Name for the returned op. 326 327 Returns: 328 An Operation that restores the variables. 329 """ 330 all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard, 331 restore_sequentially) 332 333 assign_ops = [] 334 idx = 0 335 # Load and optionally reshape on the CPU, as string tensors are not 336 # available on the GPU. 337 # TODO(touts): Re-enable restore on GPU when we can support annotating 338 # string tensors as "HostMemory" inputs. 339 for saveable in saveables: 340 shapes = None 341 if reshape: 342 # Compute the shapes, let the restore op decide if and how to do 343 # the reshape. 344 shapes = [] 345 for spec in saveable.specs: 346 v = spec.tensor 347 shape = v.get_shape() 348 if not shape.is_fully_defined(): 349 shape = array_ops.shape(v) 350 shapes.append(shape) 351 saveable_tensors = all_tensors[idx:idx + len(saveable.specs)] 352 idx += len(saveable.specs) 353 assign_ops.append(saveable.restore(saveable_tensors, shapes)) 354 355 # Create a Noop that has control dependencies from all the updates. 356 return control_flow_ops.group(*assign_ops, name=name) 357 358 def _AddShardedRestoreOps(self, filename_tensor, per_device, 359 restore_sequentially, reshape): 360 """Add Ops to restore variables from multiple devices. 361 362 Args: 363 filename_tensor: Tensor for the path of the file to load. 364 per_device: A list of (device, SaveableObject) pairs, as 365 returned by _GroupByDevices(). 366 restore_sequentially: True if we want to restore variables sequentially 367 within a shard. 
      reshape: True if we want to reshape loaded tensors to the shape of
        the corresponding variable.

    Returns:
      An Operation that restores the variables.
    """
    sharded_restores = []
    for shard, (device, saveables) in enumerate(per_device):
      with ops.device(device):
        sharded_restores.append(
            self._AddRestoreOps(
                filename_tensor,
                saveables,
                restore_sequentially,
                reshape,
                preferred_shard=shard,
                name="restore_shard"))
    return control_flow_ops.group(*sharded_restores, name="restore_all")

  def _GroupByDevices(self, saveables):
    """Group Variable tensor slices per device.

    TODO(touts): Make sure that all the devices found are on different
    job/replica/task/cpu|gpu. It would be bad if 2 were on the same device.
    It can happen if the devices are unspecified.

    Args:
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      A list of (device_name, BaseSaverBuilder.SaveableObject) tuples.
      The list is sorted by ascending device_name.

    Raises:
      ValueError: If the tensors of a saveable are on different devices.
    """
    per_device = collections.defaultdict(lambda: [])
    for saveable in saveables:
      canonical_device = set(
          pydev.canonical_name(spec.tensor.device) for spec in saveable.specs)
      if len(canonical_device) != 1:
        raise ValueError("All tensors of a saveable object must be "
                         "on the same device: %s" % saveable.name)
      per_device[canonical_device.pop()].append(saveable)
    return sorted(per_device.items(), key=lambda t: t[0])

  def build(self,
            names_to_saveables,
            reshape=False,
            sharded=False,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=10000.0,
            name=None,
            restore_sequentially=False,
            filename="model"):
    """Builds save/restore graph nodes or runs save/restore in eager mode.

    Args:
      names_to_saveables: A dictionary mapping name to a Variable or
        SaveableObject. Each name will be associated with the
        corresponding variable in the checkpoint.
      reshape: If True, allow restoring parameters from a checkpoint
        where the parameters have a different shape. This is
        only needed when you try to restore from a Dist-Belief checkpoint,
        and only sometimes.
      sharded: If True, shard the checkpoints, one per device that has
        Variable nodes.
      max_to_keep: Maximum number of checkpoints to keep. As new checkpoints
        are created, old ones are deleted. If None or 0, no checkpoints are
        deleted from the filesystem but only the last one is kept in the
        `checkpoint` file. Presently the number is only roughly enforced. For
        example, in case of restarts more than max_to_keep checkpoints may be
        kept.
      keep_checkpoint_every_n_hours: How often checkpoints should be kept.
        Defaults to 10,000 hours.
      name: String. Optional name to use as a prefix when adding operations.
      restore_sequentially: A Bool, which if true, causes restore of different
        variables to happen sequentially within each device.
      filename: If known at graph construction time, filename used for variable
        loading/saving. If None, then the default name "model" will be used.

    Returns:
      A SaverDef proto.

    Raises:
      TypeError: If 'names_to_saveables' is not a dictionary mapping string
        keys to variable Tensors.
      ValueError: If any of the keys or values in 'names_to_saveables' is not
        unique.
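
    For example, a minimal sketch of calling `build()` directly (the variable
    and checkpoint names here are illustrative, not part of the API):

    ```python
    v1 = tf.Variable(1.0, name="v1")
    builder = BaseSaverBuilder()
    # Adds save/restore ops for `v1` to the default graph and returns a
    # SaverDef proto describing them.
    saver_def = builder.build({"v1": v1}, sharded=False, filename="model")
    ```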
457 """ 458 return self._build_internal( 459 names_to_saveables=names_to_saveables, 460 reshape=reshape, 461 sharded=sharded, 462 max_to_keep=max_to_keep, 463 keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, 464 name=name, 465 restore_sequentially=restore_sequentially, 466 filename=filename) 467 468 def _build_internal(self, 469 names_to_saveables, 470 reshape=False, 471 sharded=False, 472 max_to_keep=5, 473 keep_checkpoint_every_n_hours=10000.0, 474 name=None, 475 restore_sequentially=False, 476 filename="model", 477 build_save=True, 478 build_restore=True): 479 """build() with option to only perform save and restore.""" 480 if not context.executing_eagerly() and (not build_save or 481 not build_restore): 482 raise ValueError("save and restore operations need to be built together " 483 " when eager execution is not enabled.") 484 485 saveables = saveable_object_util.validate_and_slice_inputs( 486 names_to_saveables) 487 if max_to_keep is None: 488 max_to_keep = 0 489 490 with ops.name_scope(name, "save", 491 [saveable.op for saveable in saveables]) as name: 492 # Add a placeholder string tensor for the filename. 493 filename_tensor = array_ops.placeholder_with_default( 494 filename or "model", shape=(), name="filename") 495 # Keep the name "Const" for backwards compatibility. 496 filename_tensor = array_ops.placeholder_with_default( 497 filename_tensor, shape=(), name="Const") 498 499 # Add the save ops. 500 if sharded: 501 per_device = self._GroupByDevices(saveables) 502 if build_save: 503 save_tensor = self._AddShardedSaveOps(filename_tensor, per_device) 504 if build_restore: 505 restore_op = self._AddShardedRestoreOps(filename_tensor, per_device, 506 restore_sequentially, reshape) 507 else: 508 if build_save: 509 save_tensor = self._AddSaveOps(filename_tensor, saveables) 510 if build_restore: 511 restore_op = self._AddRestoreOps(filename_tensor, saveables, 512 restore_sequentially, reshape) 513 514 # In the following use case, it's possible to have restore_ops be called 515 # something else: 516 # - Build inference graph and export a meta_graph. 517 # - Import the inference meta_graph 518 # - Extend the inference graph to a train graph. 519 # - Export a new meta_graph. 520 # Now the second restore_op will be called "restore_all_1". 521 # As such, comment out the assert for now until we know whether supporting 522 # such usage model makes sense. 523 # 524 # assert restore_op.name.endswith("restore_all"), restore_op.name 525 if context.executing_eagerly(): 526 # Store the tensor values to the tensor_names. 527 save_tensor_name = save_tensor.numpy() if build_save else "" 528 return saver_pb2.SaverDef( 529 filename_tensor_name=filename_tensor.numpy(), 530 save_tensor_name=save_tensor_name, 531 restore_op_name="", 532 max_to_keep=max_to_keep, 533 sharded=sharded, 534 keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, 535 version=self._write_version) 536 else: 537 graph = ops.get_default_graph() 538 # Do some sanity checking on collections containing 539 # PartitionedVariables. If a saved collection has a PartitionedVariable, 540 # the GraphDef needs to include concat ops to get the value (or there'll 541 # be a lookup error on load). 
      check_collection_list = graph.get_all_collection_keys()
      for collection_type in check_collection_list:
        for element in graph.get_collection(collection_type):
          if isinstance(element, variables.PartitionedVariable):
            try:
              graph.get_operation_by_name(element.name)
            except KeyError:
              # Create a concat op for this PartitionedVariable. The user may
              # not need it, but we'll try looking it up on MetaGraph restore
              # since it's in a collection.
              element.as_tensor()
      return saver_pb2.SaverDef(
          filename_tensor_name=filename_tensor.name,
          save_tensor_name=save_tensor.name,
          restore_op_name=restore_op.name,
          max_to_keep=max_to_keep,
          sharded=sharded,
          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
          version=self._write_version)


class BulkSaverBuilder(BaseSaverBuilder):
  """SaverBuilder with support for bulk restoring multiple saveables."""

  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):

    # Ignored: bulk restore is internally sequential.
    del restore_sequentially
    restore_specs = []
    for saveable in saveables:
      for spec in saveable.specs:
        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))

    names, slices, dtypes = zip(*restore_specs)
    # Load all tensors onto CPU 0 for compatibility with existing code.
    with ops.device("cpu:0"):
      return io_ops.restore_v2(filename_tensor, names, slices, dtypes)


def _get_saver_or_default():
  """Returns the saver from SAVERS collection, or creates a default one.

  This method is used by other members of the training module, such as
  `Scaffold` or `CheckpointSaverHook`.

  Returns:
    `Saver`.

  Raises:
    RuntimeError: If the SAVERS collection already has more than one item.
  """
  collection_key = ops.GraphKeys.SAVERS
  savers = ops.get_collection(collection_key)
  if savers:
    if len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor.".
          format(collection_key))
    return savers[0]
  saver = Saver(sharded=True, allow_empty=True)
  if saver is not None:
    ops.add_to_collection(collection_key, saver)
  return saver


@tf_export(v1=["train.Saver"])
class Saver(object):
  """Saves and restores variables.

  See [Variables](https://tensorflow.org/guide/variables)
  for an overview of variables, saving and restoring.

  The `Saver` class adds ops to save and restore variables to and from
  *checkpoints*. It also provides convenience methods to run these ops.

  Checkpoints are binary files in a proprietary format which map variable names
  to tensor values. The best way to examine the contents of a checkpoint is to
  load it using a `Saver`.

  Savers can automatically number checkpoint filenames with a provided counter.
  This lets you keep multiple checkpoints at different steps while training a
  model. For example you can number the checkpoint filenames with the training
  step number. To avoid filling up disks, savers manage checkpoint files
  automatically. For example, they can keep only the N most recent files, or
  one checkpoint for every N hours of training.
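
  For instance, a saver constructed as below (an illustrative sketch; the
  argument values are arbitrary) keeps the 5 most recent checkpoints plus one
  checkpoint for every 2 hours of training:

  ```python
  saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)
  ```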
629 630 You number checkpoint filenames by passing a value to the optional 631 `global_step` argument to `save()`: 632 633 ```python 634 saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0' 635 ... 636 saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000' 637 ``` 638 639 Additionally, optional arguments to the `Saver()` constructor let you control 640 the proliferation of checkpoint files on disk: 641 642 * `max_to_keep` indicates the maximum number of recent checkpoint files to 643 keep. As new files are created, older files are deleted. If None or 0, 644 no checkpoints are deleted from the filesystem but only the last one is 645 kept in the `checkpoint` file. Defaults to 5 (that is, the 5 most recent 646 checkpoint files are kept.) 647 648 * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent 649 `max_to_keep` checkpoint files, you might want to keep one checkpoint file 650 for every N hours of training. This can be useful if you want to later 651 analyze how a model progressed during a long training session. For 652 example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep 653 one checkpoint file for every 2 hours of training. The default value of 654 10,000 hours effectively disables the feature. 655 656 Note that you still have to call the `save()` method to save the model. 657 Passing these arguments to the constructor will not save variables 658 automatically for you. 659 660 A training program that saves regularly looks like: 661 662 ```python 663 ... 664 # Create a saver. 665 saver = tf.train.Saver(...variables...) 666 # Launch the graph and train, saving the model every 1,000 steps. 667 sess = tf.Session() 668 for step in xrange(1000000): 669 sess.run(..training_op..) 670 if step % 1000 == 0: 671 # Append the step number to the checkpoint name: 672 saver.save(sess, 'my-model', global_step=step) 673 ``` 674 675 In addition to checkpoint files, savers keep a protocol buffer on disk with 676 the list of recent checkpoints. This is used to manage numbered checkpoint 677 files and by `latest_checkpoint()`, which makes it easy to discover the path 678 to the most recent checkpoint. That protocol buffer is stored in a file named 679 'checkpoint' next to the checkpoint files. 680 681 If you create several savers, you can specify a different filename for the 682 protocol buffer file in the call to `save()`. 683 """ 684 685 def __init__(self, 686 var_list=None, 687 reshape=False, 688 sharded=False, 689 max_to_keep=5, 690 keep_checkpoint_every_n_hours=10000.0, 691 name=None, 692 restore_sequentially=False, 693 saver_def=None, 694 builder=None, 695 defer_build=False, 696 allow_empty=False, 697 write_version=saver_pb2.SaverDef.V2, 698 pad_step_number=False, 699 save_relative_paths=False, 700 filename=None): 701 """Creates a `Saver`. 702 703 The constructor adds ops to save and restore variables. 704 705 `var_list` specifies the variables that will be saved and restored. It can 706 be passed as a `dict` or a list: 707 708 * A `dict` of names to variables: The keys are the names that will be 709 used to save or restore the variables in the checkpoint files. 710 * A list of variables: The variables will be keyed with their op name in 711 the checkpoint files. 712 713 For example: 714 715 ```python 716 v1 = tf.Variable(..., name='v1') 717 v2 = tf.Variable(..., name='v2') 718 719 # Pass the variables as a dict: 720 saver = tf.train.Saver({'v1': v1, 'v2': v2}) 721 722 # Or pass them as a list. 
    saver = tf.train.Saver([v1, v2])
    # Passing a list is equivalent to passing a dict with the variable op names
    # as keys:
    saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
    ```

    The optional `reshape` argument, if `True`, allows restoring a variable from
    a save file where the variable had a different shape, but the same number
    of elements and type. This is useful if you have reshaped a variable and
    want to reload it from an older checkpoint.

    The optional `sharded` argument, if `True`, instructs the saver to shard
    checkpoints per device.

    Args:
      var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
        names to `SaveableObject`s. If `None`, defaults to the list of all
        saveable objects.
      reshape: If `True`, allows restoring parameters from a checkpoint
        where the variables have a different shape.
      sharded: If `True`, shard the checkpoints, one per device.
      max_to_keep: Maximum number of recent checkpoints to keep.
        Defaults to 5.
      keep_checkpoint_every_n_hours: How often to keep checkpoints.
        Defaults to 10,000 hours.
      name: String. Optional name to use as a prefix when adding operations.
      restore_sequentially: A `Bool`, which if true, causes restore of different
        variables to happen sequentially within each device. This can lower
        memory usage when restoring very large models.
      saver_def: Optional `SaverDef` proto to use instead of running the
        builder. This is only useful for specialty code that wants to recreate
        a `Saver` object for a previously built `Graph` that had a `Saver`.
        The `saver_def` proto should be the one returned by the
        `as_saver_def()` call of the `Saver` that was created for that `Graph`.
      builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
        Defaults to `BulkSaverBuilder()`.
      defer_build: If `True`, defer adding the save and restore ops to the
        `build()` call. In that case `build()` should be called before
        finalizing the graph or using the saver.
      allow_empty: If `False` (default) raise an error if there are no
        variables in the graph. Otherwise, construct the saver anyway and make
        it a no-op.
      write_version: controls what format to use when saving checkpoints. It
        also affects certain filepath matching logic. The V2 format is the
        recommended choice: it is much more optimized than V1 in terms of
        memory required and latency incurred during restore. Regardless of
        this flag, the Saver is able to restore from both V2 and V1 checkpoints.
      pad_step_number: if True, pads the global step number in the checkpoint
        filepaths to some fixed width (8 by default). This is turned off by
        default.
      save_relative_paths: If `True`, will write relative paths to the
        checkpoint state file. This is needed if the user wants to copy the
        checkpoint directory and reload from the copied directory.
      filename: If known at graph construction time, filename used for variable
        loading/saving.

    Raises:
      TypeError: If `var_list` is invalid.
      ValueError: If any of the keys or values in `var_list` are not unique.
      RuntimeError: If eager execution is enabled and `var_list` does not
        specify a list of variables to save.

    @compatibility(eager)
    When eager execution is enabled, `var_list` must specify a `list` or `dict`
    of variables to save. Otherwise, a `RuntimeError` will be raised.
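
    A minimal sketch of eager usage, assuming eager execution has been enabled
    and using illustrative variable and path names:

    ```python
    v = tf.Variable(1.0, name="v")
    saver = tf.train.Saver(var_list={"v": v})
    # `sess` is unused when executing eagerly, so `None` may be passed.
    save_path = saver.save(None, "/tmp/eager-ckpt")
    saver.restore(None, save_path)
    ```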
788 789 Although Saver works in some cases when executing eagerly, it is 790 fragile. Please switch to `tf.train.Checkpoint` or 791 `tf.keras.Model.save_weights`, which perform a more robust object-based 792 saving. These APIs will load checkpoints written by `Saver`. 793 @end_compatibility 794 """ 795 if defer_build and var_list: 796 raise ValueError( 797 "If `var_list` is provided then build cannot be deferred. " 798 "Either set defer_build=False or var_list=None.") 799 if context.executing_eagerly(): 800 logging.warning( 801 "Saver is deprecated, please switch to tf.train.Checkpoint or " 802 "tf.keras.Model.save_weights for training checkpoints. When " 803 "executing eagerly variables do not necessarily have unique names, " 804 "and so the variable.name-based lookups Saver performs are " 805 "error-prone.") 806 if var_list is None: 807 raise RuntimeError( 808 "When eager execution is enabled, `var_list` must specify a list " 809 "or dict of variables to save") 810 self._var_list = var_list 811 self._reshape = reshape 812 self._sharded = sharded 813 self._max_to_keep = max_to_keep 814 self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours 815 self._name = name 816 self._restore_sequentially = restore_sequentially 817 self.saver_def = saver_def 818 self._builder = builder 819 self._is_built = False 820 self._allow_empty = allow_empty 821 self._is_empty = None 822 self._write_version = write_version 823 self._pad_step_number = pad_step_number 824 self._filename = filename 825 self._last_checkpoints = [] 826 self._checkpoints_to_be_deleted = [] 827 if context.executing_eagerly(): 828 self._next_checkpoint_time = ( 829 time.time() + self._keep_checkpoint_every_n_hours * 3600) 830 elif not defer_build: 831 self.build() 832 if self.saver_def: 833 self._check_saver_def() 834 self._write_version = self.saver_def.version 835 self._save_relative_paths = save_relative_paths 836 # For compatibility with object-based checkpoints, we may build a second 837 # Saver to read the renamed keys. 
838 self._object_restore_saver = None 839 840 def build(self): 841 if context.executing_eagerly(): 842 raise RuntimeError("Use save/restore instead of build in eager mode.") 843 self._build(self._filename, build_save=True, build_restore=True) 844 845 def _build_eager(self, checkpoint_path, build_save, build_restore): 846 self._build( 847 checkpoint_path, build_save=build_save, build_restore=build_restore) 848 849 def _build(self, checkpoint_path, build_save, build_restore): 850 """Builds saver_def.""" 851 if not context.executing_eagerly(): 852 if self._is_built: 853 return 854 self._is_built = True 855 856 if not self.saver_def or context.executing_eagerly(): 857 if self._builder is None: 858 self._builder = BulkSaverBuilder(self._write_version) 859 860 if self._var_list is None: 861 # pylint: disable=protected-access 862 self._var_list = variables._all_saveable_objects() 863 if not self._var_list: 864 if self._allow_empty: 865 self._is_empty = True 866 return 867 else: 868 raise ValueError("No variables to save") 869 self._is_empty = False 870 871 self.saver_def = self._builder._build_internal( # pylint: disable=protected-access 872 self._var_list, 873 reshape=self._reshape, 874 sharded=self._sharded, 875 max_to_keep=self._max_to_keep, 876 keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours, 877 name=self._name, 878 restore_sequentially=self._restore_sequentially, 879 filename=checkpoint_path, 880 build_save=build_save, build_restore=build_restore) 881 elif self.saver_def and self._name: 882 # Since self._name is used as a name_scope by builder(), we are 883 # overloading the use of this field to represent the "import_scope" as 884 # well. 885 self.saver_def.filename_tensor_name = ops.prepend_name_scope( 886 self.saver_def.filename_tensor_name, self._name) 887 self.saver_def.save_tensor_name = ops.prepend_name_scope( 888 self.saver_def.save_tensor_name, self._name) 889 self.saver_def.restore_op_name = ops.prepend_name_scope( 890 self.saver_def.restore_op_name, self._name) 891 892 self._check_saver_def() 893 if not context.executing_eagerly(): 894 # Updates next checkpoint time. 895 # Set in __init__ when executing eagerly. 896 self._next_checkpoint_time = ( 897 time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600) 898 899 def _check_saver_def(self): 900 if not isinstance(self.saver_def, saver_pb2.SaverDef): 901 raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" % 902 self.saver_def) 903 if not context.executing_eagerly(): 904 if not self.saver_def.save_tensor_name: 905 raise ValueError("saver_def must specify the save_tensor_name: %s" % 906 str(self.saver_def)) 907 if not self.saver_def.restore_op_name: 908 raise ValueError("saver_def must specify the restore_op_name: %s" % 909 str(self.saver_def)) 910 911 def _CheckpointFilename(self, p): 912 """Returns the checkpoint filename given a `(filename, time)` pair. 913 914 Args: 915 p: (filename, time) pair. 916 917 Returns: 918 Checkpoint file name. 919 """ 920 name, _ = p 921 return name 922 923 def _RecordLastCheckpoint(self, latest_save_path): 924 """Manages the list of the latest checkpoints.""" 925 if not self.saver_def.max_to_keep: 926 return 927 # Remove first from list if the same name was used before. 928 for p in self._last_checkpoints: 929 if latest_save_path == self._CheckpointFilename(p): 930 self._last_checkpoints.remove(p) 931 # Append new path to list 932 self._last_checkpoints.append((latest_save_path, time.time())) 933 934 # If more than max_to_keep, remove oldest. 
935 if len(self._last_checkpoints) > self.saver_def.max_to_keep: 936 self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0)) 937 938 def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"): 939 """Deletes old checkpoints if necessary. 940 941 `self._checkpoints_to_be_deleted` is going to contain checkpoints that are 942 over `max_to_keep`. They are going to be deleted. If 943 `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint 944 every `N` hours. For example, if `N` is 0.5, an additional checkpoint is 945 kept for every 0.5 hours of training; if `N` is 10, an additional 946 checkpoint is kept for every 10 hours of training. 947 948 Args: 949 meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. 950 """ 951 if self._checkpoints_to_be_deleted: 952 p = self._checkpoints_to_be_deleted.pop(0) 953 # Do not delete the file if we keep_checkpoint_every_n_hours is set and we 954 # have reached N hours of training. 955 should_keep = p[1] > self._next_checkpoint_time 956 if should_keep: 957 self._next_checkpoint_time += ( 958 self.saver_def.keep_checkpoint_every_n_hours * 3600) 959 return 960 961 # Otherwise delete the files. 962 try: 963 checkpoint_management.remove_checkpoint( 964 self._CheckpointFilename(p), self.saver_def.version, 965 meta_graph_suffix) 966 except Exception as e: # pylint: disable=broad-except 967 logging.warning("Ignoring: %s", str(e)) 968 969 def as_saver_def(self): 970 """Generates a `SaverDef` representation of this saver. 971 972 Returns: 973 A `SaverDef` proto. 974 """ 975 return self.saver_def 976 977 def to_proto(self, export_scope=None): 978 """Converts this `Saver` to a `SaverDef` protocol buffer. 979 980 Args: 981 export_scope: Optional `string`. Name scope to remove. 982 983 Returns: 984 A `SaverDef` protocol buffer. 985 """ 986 if export_scope is None: 987 return self.saver_def 988 989 if not (self.saver_def.filename_tensor_name.startswith(export_scope) and 990 self.saver_def.save_tensor_name.startswith(export_scope) and 991 self.saver_def.restore_op_name.startswith(export_scope)): 992 return None 993 994 saver_def = saver_pb2.SaverDef() 995 saver_def.CopyFrom(self.saver_def) 996 saver_def.filename_tensor_name = ops.strip_name_scope( 997 saver_def.filename_tensor_name, export_scope) 998 saver_def.save_tensor_name = ops.strip_name_scope( 999 saver_def.save_tensor_name, export_scope) 1000 saver_def.restore_op_name = ops.strip_name_scope( 1001 saver_def.restore_op_name, export_scope) 1002 return saver_def 1003 1004 @staticmethod 1005 def from_proto(saver_def, import_scope=None): 1006 """Returns a `Saver` object created from `saver_def`. 1007 1008 Args: 1009 saver_def: a `SaverDef` protocol buffer. 1010 import_scope: Optional `string`. Name scope to use. 1011 1012 Returns: 1013 A `Saver` built from saver_def. 1014 """ 1015 return Saver(saver_def=saver_def, name=import_scope) 1016 1017 @property 1018 def last_checkpoints(self): 1019 """List of not-yet-deleted checkpoint filenames. 1020 1021 You can pass any of the returned values to `restore()`. 1022 1023 Returns: 1024 A list of checkpoint filenames, sorted from oldest to newest. 1025 """ 1026 return list(self._CheckpointFilename(p) for p in self._last_checkpoints) 1027 1028 def set_last_checkpoints(self, last_checkpoints): 1029 """DEPRECATED: Use set_last_checkpoints_with_time. 1030 1031 Sets the list of old checkpoint filenames. 1032 1033 Args: 1034 last_checkpoints: A list of checkpoint filenames. 

    Raises:
      AssertionError: If last_checkpoints is not a list.
    """
    assert isinstance(last_checkpoints, list)
    # We use a timestamp of +inf so that this checkpoint will never be
    # deleted. This is both safe and backwards compatible to a previous
    # version of the code which used s[1] as the "timestamp".
    self._last_checkpoints = [(s, np.inf) for s in last_checkpoints]

  def set_last_checkpoints_with_time(self, last_checkpoints_with_time):
    """Sets the list of old checkpoint filenames and timestamps.

    Args:
      last_checkpoints_with_time: A list of tuples of checkpoint filenames and
        timestamps.

    Raises:
      AssertionError: If last_checkpoints_with_time is not a list.
    """
    assert isinstance(last_checkpoints_with_time, list)
    self._last_checkpoints = last_checkpoints_with_time

  def recover_last_checkpoints(self, checkpoint_paths):
    """Recovers the internal saver state after a crash.

    This method is useful for recovering the "self._last_checkpoints" state.

    Globs for the checkpoints pointed to by `checkpoint_paths`. If the files
    exist, use their mtime as the checkpoint timestamp.

    Args:
      checkpoint_paths: a list of checkpoint paths.
    """
    mtimes = checkpoint_management.get_checkpoint_mtimes(checkpoint_paths)
    self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes)))

  def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True,
           strip_default_attrs=False,
           save_debug_info=False):
    # pylint: disable=line-too-long
    """Saves variables.

    This method runs the ops added by the constructor for saving variables.
    It requires a session in which the graph was launched. The variables to
    save must also have been initialized.

    The method returns the path prefix of the newly created checkpoint files.
    This string can be passed directly to a call to `restore()`.

    Args:
      sess: A Session to use to save the variables.
      save_path: String. Prefix of filenames created for the checkpoint.
      global_step: If provided, the global step number is appended to
        `save_path` to create the checkpoint filenames. The optional argument
        can be a `Tensor`, a `Tensor` name or an integer.
      latest_filename: Optional name for the protocol buffer file that will
        contain the list of most recent checkpoints. That file,
        kept in the same directory as the checkpoint files, is automatically
        managed by the saver to keep track of recent checkpoints. Defaults to
        'checkpoint'.
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
      write_meta_graph: `Boolean` indicating whether or not to write the meta
        graph file.
      write_state: `Boolean` indicating whether or not to write the
        `CheckpointStateProto`.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
        which is in the same directory as save_path and with `_debug` added
        before the file extension.
This is only enabled when `write_meta_graph` is 1114 `True` 1115 1116 Returns: 1117 A string: path prefix used for the checkpoint files. If the saver is 1118 sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn' 1119 is the number of shards created. 1120 If the saver is empty, returns None. 1121 1122 Raises: 1123 TypeError: If `sess` is not a `Session`. 1124 ValueError: If `latest_filename` contains path components, or if it 1125 collides with `save_path`. 1126 RuntimeError: If save and restore ops weren't built. 1127 """ 1128 # pylint: enable=line-too-long 1129 if not self._is_built and not context.executing_eagerly(): 1130 raise RuntimeError( 1131 "`build()` should be called before save if defer_build==True") 1132 if latest_filename is None: 1133 latest_filename = "checkpoint" 1134 if self._write_version != saver_pb2.SaverDef.V2: 1135 logging.warning("*******************************************************") 1136 logging.warning("TensorFlow's V1 checkpoint format has been deprecated.") 1137 logging.warning("Consider switching to the more efficient V2 format:") 1138 logging.warning(" `tf.train.Saver(write_version=tf.train.SaverDef.V2)`") 1139 logging.warning("now on by default.") 1140 logging.warning("*******************************************************") 1141 1142 if os.path.split(latest_filename)[0]: 1143 raise ValueError("'latest_filename' must not contain path components") 1144 1145 if global_step is not None: 1146 if not isinstance(global_step, compat.integral_types): 1147 global_step = training_util.global_step(sess, global_step) 1148 checkpoint_file = "%s-%d" % (save_path, global_step) 1149 if self._pad_step_number: 1150 # Zero-pads the step numbers, so that they are sorted when listed. 1151 checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step)) 1152 else: 1153 checkpoint_file = save_path 1154 if os.path.basename( 1155 save_path) == latest_filename and not self._sharded: 1156 # Guard against collision between data file and checkpoint state file. 
        raise ValueError(
            "'latest_filename' collides with 'save_path': '%s' and '%s'" %
            (latest_filename, save_path))

    if (not context.executing_eagerly() and
        not isinstance(sess, session.SessionInterface)):
      raise TypeError("'sess' must be a Session; %s" % sess)

    save_path_parent = os.path.dirname(save_path)
    if not self._is_empty:
      try:
        if context.executing_eagerly():
          self._build_eager(
              checkpoint_file, build_save=True, build_restore=False)
          model_checkpoint_path = self.saver_def.save_tensor_name
        else:
          model_checkpoint_path = sess.run(
              self.saver_def.save_tensor_name,
              {self.saver_def.filename_tensor_name: checkpoint_file})

        model_checkpoint_path = compat.as_str(model_checkpoint_path)
        if write_state:
          self._RecordLastCheckpoint(model_checkpoint_path)
          checkpoint_management.update_checkpoint_state_internal(
              save_dir=save_path_parent,
              model_checkpoint_path=model_checkpoint_path,
              all_model_checkpoint_paths=self.last_checkpoints,
              latest_filename=latest_filename,
              save_relative_paths=self._save_relative_paths)
          self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
      except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
        if not gfile.IsDirectory(save_path_parent):
          exc = ValueError(
              "Parent directory of {} doesn't exist, can't save.".format(
                  save_path))
        raise exc

    if write_meta_graph:
      meta_graph_filename = checkpoint_management.meta_graph_filename(
          checkpoint_file, meta_graph_suffix=meta_graph_suffix)
      if not context.executing_eagerly():
        with sess.graph.as_default():
          self.export_meta_graph(
              meta_graph_filename, strip_default_attrs=strip_default_attrs,
              save_debug_info=save_debug_info)

    if self._is_empty:
      return None
    else:
      return model_checkpoint_path

  def export_meta_graph(self,
                        filename=None,
                        collection_list=None,
                        as_text=False,
                        export_scope=None,
                        clear_devices=False,
                        clear_extraneous_savers=False,
                        strip_default_attrs=False,
                        save_debug_info=False):
    # pylint: disable=line-too-long
    """Writes `MetaGraphDef` to save_path/filename.

    Args:
      filename: Optional meta_graph filename including the path.
      collection_list: List of string keys to collect.
      as_text: If `True`, writes the meta_graph as an ASCII proto.
      export_scope: Optional `string`. Name scope to remove.
      clear_devices: Whether or not to clear the device field for an `Operation`
        or `Tensor` during export.
      clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that are not associated
        with this Saver.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
        which is in the same directory as filename and with `_debug` added
        before the file extension.

    Returns:
      A `MetaGraphDef` proto.
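
    For example (an illustrative sketch; the path is arbitrary and `saver` is
    assumed to be a previously constructed `tf.train.Saver`):

    ```python
    # Writes the MetaGraphDef, including this saver's SaverDef, to disk and
    # also returns the proto.
    meta_graph_def = saver.export_meta_graph('/tmp/my-model.meta')
    ```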
1239 """ 1240 # pylint: enable=line-too-long 1241 return export_meta_graph( 1242 filename=filename, 1243 graph_def=ops.get_default_graph().as_graph_def(add_shapes=True), 1244 saver_def=self.saver_def, 1245 collection_list=collection_list, 1246 as_text=as_text, 1247 export_scope=export_scope, 1248 clear_devices=clear_devices, 1249 clear_extraneous_savers=clear_extraneous_savers, 1250 strip_default_attrs=strip_default_attrs, 1251 save_debug_info=save_debug_info) 1252 1253 def restore(self, sess, save_path): 1254 """Restores previously saved variables. 1255 1256 This method runs the ops added by the constructor for restoring variables. 1257 It requires a session in which the graph was launched. The variables to 1258 restore do not have to have been initialized, as restoring is itself a way 1259 to initialize variables. 1260 1261 The `save_path` argument is typically a value previously returned from a 1262 `save()` call, or a call to `latest_checkpoint()`. 1263 1264 Args: 1265 sess: A `Session` to use to restore the parameters. None in eager mode. 1266 save_path: Path where parameters were previously saved. 1267 1268 Raises: 1269 ValueError: If save_path is None or not a valid checkpoint. 1270 """ 1271 if self._is_empty: 1272 return 1273 if save_path is None: 1274 raise ValueError("Can't load save_path when it is None.") 1275 1276 if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)): 1277 raise ValueError("The passed save_path is not a valid checkpoint: " 1278 + compat.as_text(save_path)) 1279 1280 logging.info("Restoring parameters from %s", compat.as_text(save_path)) 1281 try: 1282 if context.executing_eagerly(): 1283 self._build_eager(save_path, build_save=False, build_restore=True) 1284 else: 1285 sess.run(self.saver_def.restore_op_name, 1286 {self.saver_def.filename_tensor_name: save_path}) 1287 except errors.NotFoundError as err: 1288 # There are three common conditions that might cause this error: 1289 # 0. The file is missing. We ignore here, as this is checked above. 1290 # 1. This is an object-based checkpoint trying name-based loading. 1291 # 2. The graph has been altered and a variable or other name is missing. 1292 1293 # 1. The checkpoint would not be loaded successfully as is. Try to parse 1294 # it as an object-based checkpoint. 1295 try: 1296 names_to_keys = object_graph_key_mapping(save_path) 1297 except errors.NotFoundError: 1298 # 2. This is not an object-based checkpoint, which likely means there 1299 # is a graph mismatch. Re-raise the original error with 1300 # a helpful message (b/110263146) 1301 raise _wrap_restore_error_with_msg( 1302 err, "a Variable name or other graph key that is missing") 1303 1304 # This is an object-based checkpoint. We'll print a warning and then do 1305 # the restore. 1306 logging.warning( 1307 "Restoring an object-based checkpoint using a name-based saver. This " 1308 "may be somewhat fragile, and will re-build the Saver. Instead, " 1309 "consider loading object-based checkpoints using " 1310 "tf.train.Checkpoint().") 1311 self._object_restore_saver = saver_from_object_based_checkpoint( 1312 checkpoint_path=save_path, 1313 var_list=self._var_list, 1314 builder=self._builder, 1315 names_to_keys=names_to_keys, 1316 cached_saver=self._object_restore_saver) 1317 self._object_restore_saver.restore(sess=sess, save_path=save_path) 1318 except errors.InvalidArgumentError as err: 1319 # There is a mismatch between the graph and the checkpoint being loaded. 
1320 # We add a more reasonable error message here to help users (b/110263146) 1321 raise _wrap_restore_error_with_msg( 1322 err, "a mismatch between the current graph and the graph") 1323 1324 @staticmethod 1325 def _add_collection_def(meta_graph_def, key, export_scope=None): 1326 """Adds a collection to MetaGraphDef protocol buffer. 1327 1328 Args: 1329 meta_graph_def: MetaGraphDef protocol buffer. 1330 key: One of the GraphKeys or user-defined string. 1331 export_scope: Optional `string`. Name scope to remove. 1332 """ 1333 meta_graph.add_collection_def(meta_graph_def, key, 1334 export_scope=export_scope) 1335 1336 1337@tf_export(v1=["train.import_meta_graph"]) 1338def import_meta_graph(meta_graph_or_file, clear_devices=False, 1339 import_scope=None, **kwargs): 1340 """Recreates a Graph saved in a `MetaGraphDef` proto. 1341 1342 This function takes a `MetaGraphDef` protocol buffer as input. If 1343 the argument is a file containing a `MetaGraphDef` protocol buffer , 1344 it constructs a protocol buffer from the file content. The function 1345 then adds all the nodes from the `graph_def` field to the 1346 current graph, recreates all the collections, and returns a saver 1347 constructed from the `saver_def` field. 1348 1349 In combination with `export_meta_graph()`, this function can be used to 1350 1351 * Serialize a graph along with other Python objects such as `QueueRunner`, 1352 `Variable` into a `MetaGraphDef`. 1353 1354 * Restart training from a saved graph and checkpoints. 1355 1356 * Run inference from a saved graph and checkpoints. 1357 1358 ```Python 1359 ... 1360 # Create a saver. 1361 saver = tf.train.Saver(...variables...) 1362 # Remember the training_op we want to run by adding it to a collection. 1363 tf.add_to_collection('train_op', train_op) 1364 sess = tf.Session() 1365 for step in xrange(1000000): 1366 sess.run(train_op) 1367 if step % 1000 == 0: 1368 # Saves checkpoint, which by default also exports a meta_graph 1369 # named 'my-model-global_step.meta'. 1370 saver.save(sess, 'my-model', global_step=step) 1371 ``` 1372 1373 Later we can continue training from this saved `meta_graph` without building 1374 the model from scratch. 1375 1376 ```Python 1377 with tf.Session() as sess: 1378 new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta') 1379 new_saver.restore(sess, 'my-save-dir/my-model-10000') 1380 # tf.get_collection() returns a list. In this example we only want the 1381 # first one. 1382 train_op = tf.get_collection('train_op')[0] 1383 for step in xrange(1000000): 1384 sess.run(train_op) 1385 ``` 1386 1387 NOTE: Restarting training from saved `meta_graph` only works if the 1388 device assignments have not changed. 1389 1390 Example 2: 1391 Variables, placeholders, and independent operations can also be stored, as 1392 shown in the following example. 1393 1394 ```Python 1395 # Saving contents and operations. 1396 v1 = tf.placeholder(tf.float32, name="v1") 1397 v2 = tf.placeholder(tf.float32, name="v2") 1398 v3 = tf.mul(v1, v2) 1399 vx = tf.Variable(10.0, name="vx") 1400 v4 = tf.add(v3, vx, name="v4") 1401 saver = tf.train.Saver([vx]) 1402 sess = tf.Session() 1403 sess.run(tf.initialize_all_variables()) 1404 sess.run(vx.assign(tf.add(vx, vx))) 1405 result = sess.run(v4, feed_dict={v1:12.0, v2:3.3}) 1406 print(result) 1407 saver.save(sess, "./model_ex1") 1408 ``` 1409 1410 Later this model can be restored and contents loaded. 1411 1412 ```Python 1413 # Restoring variables and running operations. 
1414 saver = tf.train.import_meta_graph("./model_ex1.meta") 1415 sess = tf.Session() 1416 saver.restore(sess, "./model_ex1") 1417 result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3}) 1418 print(result) 1419 ``` 1420 1421 Args: 1422 meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including 1423 the path) containing a `MetaGraphDef`. 1424 clear_devices: Whether or not to clear the device field for an `Operation` 1425 or `Tensor` during import. 1426 import_scope: Optional `string`. Name scope to add. Only used when 1427 initializing from protocol buffer. 1428 **kwargs: Optional keyed arguments. 1429 1430 Returns: 1431 A saver constructed from `saver_def` in `MetaGraphDef` or None. 1432 1433 A None value is returned if no variables exist in the `MetaGraphDef` 1434 (i.e., there are no variables to restore). 1435 1436 Raises: 1437 RuntimeError: If called with eager execution enabled. 1438 1439 @compatibility(eager) 1440 Exporting/importing meta graphs is not supported. No graph exists when eager 1441 execution is enabled. 1442 @end_compatibility 1443 """ # pylint: disable=g-doc-exception 1444 return _import_meta_graph_with_return_elements( 1445 meta_graph_or_file, clear_devices, import_scope, **kwargs)[0] 1446 1447 1448def _import_meta_graph_with_return_elements( 1449 meta_graph_or_file, clear_devices=False, import_scope=None, 1450 return_elements=None, **kwargs): 1451 """Import MetaGraph, and return both a saver and returned elements.""" 1452 if context.executing_eagerly(): 1453 raise RuntimeError("Exporting/importing meta graphs is not supported when " 1454 "eager execution is enabled. No graph exists when eager " 1455 "execution is enabled.") 1456 if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef): 1457 meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file) 1458 else: 1459 meta_graph_def = meta_graph_or_file 1460 1461 imported_vars, imported_return_elements = ( 1462 meta_graph.import_scoped_meta_graph_with_return_elements( 1463 meta_graph_def, 1464 clear_devices=clear_devices, 1465 import_scope=import_scope, 1466 return_elements=return_elements, 1467 **kwargs)) 1468 1469 saver = _create_saver_from_imported_meta_graph( 1470 meta_graph_def, import_scope, imported_vars) 1471 return saver, imported_return_elements 1472 1473 1474def _create_saver_from_imported_meta_graph( 1475 meta_graph_def, import_scope, imported_vars): 1476 """Return a saver for restoring variable values to an imported MetaGraph.""" 1477 if meta_graph_def.HasField("saver_def"): 1478 # Infer the scope that is prepended by `import_scoped_meta_graph`. 1479 scope = import_scope 1480 var_names = list(imported_vars.keys()) 1481 if var_names: 1482 sample_key = var_names[0] 1483 sample_var = imported_vars[sample_key] 1484 scope = sample_var.name[:-len(sample_key)] 1485 1486 return Saver(saver_def=meta_graph_def.saver_def, name=scope) 1487 else: 1488 if variables._all_saveable_objects(scope=import_scope): # pylint: disable=protected-access 1489 # Return the default saver instance for all graph variables. 1490 return Saver() 1491 else: 1492 # If no graph variables exist, then a Saver cannot be constructed. 


@tf_export(v1=["train.export_meta_graph"])
def export_meta_graph(filename=None,
                      meta_info_def=None,
                      graph_def=None,
                      saver_def=None,
                      collection_list=None,
                      as_text=False,
                      graph=None,
                      export_scope=None,
                      clear_devices=False,
                      clear_extraneous_savers=False,
                      strip_default_attrs=False,
                      save_debug_info=False,
                      **kwargs):
  # pylint: disable=line-too-long
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    meta_info_def: `MetaInfoDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported. graph_def and export_scope cannot both be specified.
    clear_devices: Whether or not to clear the device field for an `Operation`
      or `Tensor` during export.
    clear_extraneous_savers: Remove any Saver-related information from the
      graph (both Save/Restore ops and SaverDefs) that are not associated
      with the provided SaverDef.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
      removed from the NodeDefs. For a detailed guide, see
      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
      which is in the same directory as the filename and has `_debug` added
      before the file extension.
    **kwargs: Optional keyed arguments.

  Returns:
    A `MetaGraphDef` proto.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Exporting/importing meta graphs is not supported unless both `graph_def` and
  `graph` are provided. No graph exists when eager execution is enabled.
  @end_compatibility
  """
  # pylint: enable=line-too-long
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise RuntimeError("Exporting/importing meta graphs is not supported when "
                       "eager execution is enabled. No graph exists when eager "
                       "execution is enabled.")
  meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
      filename=filename,
      meta_info_def=meta_info_def,
      graph_def=graph_def,
      saver_def=saver_def,
      collection_list=collection_list,
      as_text=as_text,
      graph=graph,
      export_scope=export_scope,
      clear_devices=clear_devices,
      clear_extraneous_savers=clear_extraneous_savers,
      strip_default_attrs=strip_default_attrs,
      save_debug_info=save_debug_info,
      **kwargs)
  return meta_graph_def
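

# The docstring above has no usage example, so the sketch below shows one
# plausible (non-authoritative) way `export_meta_graph` pairs with
# `import_meta_graph` in graph mode (TF1 style). The "/tmp/example.meta" path
# is a hypothetical placeholder chosen for illustration.
def _example_export_then_import():
  """Hypothetical sketch: export a small graph, then re-import it."""
  g = ops.Graph()
  with g.as_default():
    constant_op.constant(1.0, name="one")
    # Write /tmp/example.meta; device placements are cleared so the proto can
    # be imported on a machine with a different device layout.
    exported = export_meta_graph(filename="/tmp/example.meta", graph=g,
                                 clear_devices=True)
  with ops.Graph().as_default():
    # With no variables in the MetaGraphDef, import_meta_graph returns None,
    # but the graph structure (including the "one" op) is still recreated.
    restored_saver = import_meta_graph("/tmp/example.meta")
  return exported, restored_saver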


def _wrap_restore_error_with_msg(err, extra_verbiage):
  err_msg = ("Restoring from checkpoint failed. This is most likely "
             "due to {} from the checkpoint. Please ensure that you "
             "have not altered the graph expected based on the checkpoint. "
             "Original error:\n\n{}").format(extra_verbiage, err.message)
  return err.__class__(err.node_def, err.op, err_msg)


ops.register_proto_function(
    ops.GraphKeys.SAVERS,
    proto_type=saver_pb2.SaverDef,
    to_proto=Saver.to_proto,
    from_proto=Saver.from_proto)


def object_graph_key_mapping(checkpoint_path):
  """Return name to key mappings from the checkpoint.

  Args:
    checkpoint_path: string, path to object-based checkpoint

  Returns:
    Dictionary mapping tensor names to checkpoint keys.
  """
  reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
  object_graph_string = reader.get_tensor(
      trackable.OBJECT_GRAPH_PROTO_KEY)
  object_graph_proto = (
      trackable_object_graph_pb2.TrackableObjectGraph())
  object_graph_proto.ParseFromString(object_graph_string)
  names_to_keys = {}
  for node in object_graph_proto.nodes:
    for attribute in node.attributes:
      names_to_keys[attribute.full_name] = attribute.checkpoint_key
  return names_to_keys
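

# Illustrative sketch (not part of the public API) of how the mapping returned
# by `object_graph_key_mapping` is typically consumed: the dict keys are
# variable full names and the values are the object-graph checkpoint keys
# written by tf.train.Checkpoint (commonly ending in
# "/.ATTRIBUTES/VARIABLE_VALUE"). `ckpt_prefix` below is a hypothetical
# checkpoint prefix supplied by the caller.
def _example_print_checkpoint_key_mapping(ckpt_prefix):
  """Hypothetical sketch: log the name-to-key remapping of a checkpoint."""
  try:
    names_to_keys = object_graph_key_mapping(ckpt_prefix)
  except errors.NotFoundError:
    # The checkpoint has no serialized object graph, so it is name-based and
    # needs no remapping.
    return
  for name, key in sorted(names_to_keys.items()):
    logging.info("%s is stored under checkpoint key %s", name, key)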


def saver_from_object_based_checkpoint(
    checkpoint_path, var_list=None, builder=None, names_to_keys=None,
    cached_saver=None):
  """Return a `Saver` which reads from an object-based checkpoint.

  This function validates that all variables in the variables list are remapped
  in the object-based checkpoint (or `names_to_keys` dict if provided). A
  saver will be created with the list of remapped variables.

  The `cached_saver` argument allows the user to pass in a previously created
  saver, so multiple `saver.restore()` calls don't pollute the graph when graph
  building. This assumes that keys are consistent, meaning that the
    1) `checkpoint_path` checkpoint, and
    2) checkpoint used to create the `cached_saver`
  are the same type of object-based checkpoint. If this argument is set, this
  function will simply validate that all variables have been remapped by the
  checkpoint at `checkpoint_path`.

  Note that in general, `tf.train.Checkpoint` should be used to restore/save an
  object-based checkpoint.

  Args:
    checkpoint_path: string, path to object-based checkpoint
    var_list: list of `Variables` that appear in the checkpoint. If `None`,
      `var_list` will be set to all saveable objects.
    builder: a `BaseSaverBuilder` instance. If `None`, a new `BulkSaverBuilder`
      will be created.
    names_to_keys: dict mapping string tensor names to checkpoint keys. If
      `None`, this dict will be generated from the checkpoint file.
    cached_saver: Cached `Saver` object with remapped variables.

  Returns:
    `Saver` with remapped variables for reading from an object-based checkpoint.

  Raises:
    ValueError: If the checkpoint provided is not an object-based checkpoint.
    NotFoundError: If one of the variables in `var_list` cannot be found in the
      checkpoint. This could mean the checkpoint or `names_to_keys` mapping is
      missing the variable.
  """
  if names_to_keys is None:
    try:
      names_to_keys = object_graph_key_mapping(checkpoint_path)
    except errors.NotFoundError:
      raise ValueError("Checkpoint in %s not an object-based checkpoint."
                       % checkpoint_path)
  if var_list is None:
    var_list = variables._all_saveable_objects()  # pylint: disable=protected-access
  if builder is None:
    builder = BulkSaverBuilder()

  saveables = saveable_object_util.validate_and_slice_inputs(var_list)
  current_names = set()
  for saveable in saveables:
    for spec in saveable.specs:
      current_names.add(spec.name)
  previous_names = set(names_to_keys.keys())
  missing_names = current_names - previous_names
  if missing_names:
    extra_names = previous_names - current_names
    intersecting_names = previous_names.intersection(current_names)
    raise errors.NotFoundError(
        None, None,
        message=(
            "\n\nExisting variables not in the checkpoint: %s\n\n"
            "Variable names used when this checkpoint was written which don't "
            "exist now: %s\n\n"
            "(%d variable name(s) did match)\n\n"
            "Could not find some variables in the checkpoint (see names "
            "above). Saver was attempting to load an object-based checkpoint "
            "(saved using tf.train.Checkpoint or tf.keras.Model.save_weights) "
            "using variable names. If the checkpoint was written with eager "
            "execution enabled, it's possible that variable names have "
            "changed (for example missing a '_1' suffix). It's also "
            "possible that there are new variables which did not exist "
            "when the checkpoint was written. You can construct a "
            "Saver(var_list=...) with only the variables which previously "
            "existed, and if variable names have changed you may need to "
            "make this a dictionary with the old names as keys. If you're "
            "using an Estimator, you'll need to return a tf.train.Saver "
            "inside a tf.train.Scaffold from your model_fn.")
        % (", ".join(sorted(missing_names)), ", ".join(sorted(extra_names)),
           len(intersecting_names)))
  for saveable in saveables:
    for spec in saveable.specs:
      spec.name = names_to_keys[spec.name]
  if cached_saver is None:
    return Saver(saveables)
  return cached_saver
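

# Illustrative sketch (assumptions noted inline, not part of the public API)
# of how the helper above can be used to read a tf.train.Checkpoint-style
# checkpoint through a name-based Saver. `checkpoint_prefix` is a hypothetical
# checkpoint prefix supplied by the caller, and `sess` an open Session.
def _example_restore_object_based_checkpoint(sess, checkpoint_prefix):
  """Hypothetical sketch: restore graph variables from an object checkpoint."""
  try:
    # Remaps each variable's save spec to its object-graph checkpoint key.
    remapped_saver = saver_from_object_based_checkpoint(checkpoint_prefix)
  except ValueError:
    # Not an object-based checkpoint; fall back to an ordinary name-based
    # Saver over all saveable objects in the default graph.
    remapped_saver = Saver()
  remapped_saver.restore(sess, checkpoint_prefix)
  return remapped_saver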