1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16# pylint: disable=invalid-name 17"""Save and restore variables. 18 19Symbols in this file are deprecated. See replacements in 20tensorflow/python/training/trackable and tensorflow/python/training/saving. 21""" 22from __future__ import absolute_import 23from __future__ import division 24from __future__ import print_function 25 26import collections 27import os.path 28import time 29 30import numpy as np 31from tensorflow.core.protobuf import meta_graph_pb2 32from tensorflow.core.protobuf import saver_pb2 33from tensorflow.core.protobuf import trackable_object_graph_pb2 34from tensorflow.python.client import session 35from tensorflow.python.eager import context 36from tensorflow.python.framework import constant_op 37from tensorflow.python.framework import device as pydev 38from tensorflow.python.framework import errors 39from tensorflow.python.framework import meta_graph 40from tensorflow.python.framework import ops 41from tensorflow.python.ops import array_ops 42from tensorflow.python.ops import control_flow_ops 43from tensorflow.python.ops import gen_io_ops 44from tensorflow.python.ops import io_ops 45from tensorflow.python.ops import string_ops 46from tensorflow.python.ops import variables 47from tensorflow.python.platform import gfile 48from tensorflow.python.platform 
import tf_logging as logging 49from tensorflow.python.training import checkpoint_management 50from tensorflow.python.training import py_checkpoint_reader 51from tensorflow.python.training import training_util 52from tensorflow.python.training.saving import saveable_object 53from tensorflow.python.training.saving import saveable_object_util 54from tensorflow.python.training.tracking import base as trackable 55from tensorflow.python.util import compat 56from tensorflow.python.util.tf_export import tf_export 57 58# TODO(allenl): Remove these aliases once all users are migrated off. 59get_checkpoint_state = checkpoint_management.get_checkpoint_state 60update_checkpoint_state = checkpoint_management.update_checkpoint_state 61generate_checkpoint_state_proto = ( 62 checkpoint_management.generate_checkpoint_state_proto) 63latest_checkpoint = checkpoint_management.latest_checkpoint 64checkpoint_exists = checkpoint_management.checkpoint_exists 65get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes 66remove_checkpoint = checkpoint_management.remove_checkpoint 67 68 69class BaseSaverBuilder(object): 70 """Base class for Savers. 71 72 Can be extended to create different Ops. 73 """ 74 75 SaveSpec = saveable_object.SaveSpec 76 SaveableObject = saveable_object.SaveableObject 77 78 # Aliases for code which was moved but still has lots of users. 79 VariableSaveable = saveable_object_util.ReferenceVariableSaveable 80 ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable 81 82 def __init__(self, write_version=saver_pb2.SaverDef.V2): 83 self._write_version = write_version 84 85 def save_op(self, filename_tensor, saveables): 86 """Create an Op to save 'saveables'. 87 88 This is intended to be overridden by subclasses that want to generate 89 different Ops. 90 91 Args: 92 filename_tensor: String Tensor. 93 saveables: A list of BaseSaverBuilder.SaveableObject objects. 94 95 Returns: 96 An Operation that save the variables. 
97 98 Raises: 99 RuntimeError: (implementation detail) if "self._write_version" is an 100 unexpected value. 101 """ 102 # pylint: disable=protected-access 103 tensor_names = [] 104 tensors = [] 105 tensor_slices = [] 106 for saveable in saveables: 107 for spec in saveable.specs: 108 tensor_names.append(spec.name) 109 tensors.append(spec.tensor) 110 tensor_slices.append(spec.slice_spec) 111 if self._write_version == saver_pb2.SaverDef.V1: 112 return io_ops._save( 113 filename=filename_tensor, 114 tensor_names=tensor_names, 115 tensors=tensors, 116 tensor_slices=tensor_slices) 117 elif self._write_version == saver_pb2.SaverDef.V2: 118 # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix 119 # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>". 120 return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices, 121 tensors) 122 else: 123 raise RuntimeError("Unexpected write_version: " + self._write_version) 124 125 def bulk_restore(self, filename_tensor, saveables, preferred_shard, 126 restore_sequentially): 127 """Restore all tensors contained in saveables. 128 129 By default, this issues separate calls to `restore_op` for each saveable. 130 Subclasses may override to load multiple saveables in a single call. 131 132 Args: 133 filename_tensor: String Tensor. 134 saveables: List of BaseSaverBuilder.SaveableObject objects. 135 preferred_shard: Int. Shard to open first when loading a sharded file. 136 restore_sequentially: Unused. Bool. If true, each restore is sequential. 137 138 Returns: 139 A list of Tensors resulting from reading 'saveable' from 140 'filename'. 
141 142 """ 143 del restore_sequentially 144 all_tensors = [] 145 for saveable in saveables: 146 if saveable.device: 147 device = saveable_object_util.set_cpu0(saveable.device) 148 else: 149 device = None 150 with ops.device(device): 151 all_tensors.extend( 152 self.restore_op(filename_tensor, saveable, preferred_shard)) 153 return all_tensors 154 155 # pylint: disable=unused-argument 156 def restore_op(self, filename_tensor, saveable, preferred_shard): 157 """Create ops to restore 'saveable'. 158 159 This is intended to be overridden by subclasses that want to generate 160 different Ops. 161 162 Args: 163 filename_tensor: String Tensor. 164 saveable: A BaseSaverBuilder.SaveableObject object. 165 preferred_shard: Int. Shard to open first when loading a sharded file. 166 167 Returns: 168 A list of Tensors resulting from reading 'saveable' from 169 'filename'. 170 """ 171 # pylint: disable=protected-access 172 tensors = [] 173 for spec in saveable.specs: 174 tensors.append( 175 io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec], 176 [spec.dtype])[0]) 177 178 return tensors 179 180 # pylint: enable=unused-argument 181 182 def sharded_filename(self, filename_tensor, shard, num_shards): 183 """Append sharding information to a filename. 184 185 Args: 186 filename_tensor: A string tensor. 187 shard: Integer. The shard for the filename. 188 num_shards: An int Tensor for the number of shards. 189 190 Returns: 191 A string tensor. 192 """ 193 return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards) 194 195 def _AddSaveOps(self, filename_tensor, saveables): 196 """Add ops to save variables that are on the same shard. 197 198 Args: 199 filename_tensor: String Tensor. 200 saveables: A list of SaveableObject objects. 201 202 Returns: 203 A tensor with the filename used to save. 
204 """ 205 save = self.save_op(filename_tensor, saveables) 206 return control_flow_ops.with_dependencies([save], filename_tensor) 207 208 def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device): 209 """Add ops to save the params per shard, for the V2 format. 210 211 Note that the sharded save procedure for the V2 format is different from 212 V1: there is a special "merge" step that merges the small metadata produced 213 from each device. 214 215 Args: 216 checkpoint_prefix: scalar String Tensor. Interpreted *NOT AS A FILENAME*, 217 but as a prefix of a V2 checkpoint; 218 per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as 219 returned by _GroupByDevices(). 220 221 Returns: 222 An op to save the variables, which, when evaluated, returns the prefix 223 "<user-fed prefix>" only and does not include the sharded spec suffix. 224 """ 225 # IMPLEMENTATION DETAILS: most clients should skip. 226 # 227 # Suffix for any well-formed "checkpoint_prefix", when sharded. 228 # Transformations: 229 # * Users pass in "save_path" in save() and restore(). Say "myckpt". 230 # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>. 231 # * If checkpoint_prefix is a S3 bucket path ".part" is appended to it 232 # * Otherwise _temp/part is appended which is normalized relative to the OS 233 # Example: 234 # During runtime, a temporary directory is first created, which contains 235 # files 236 # 237 # <train dir>/myckpt_temp/ 238 # part-?????-of-?????{.index, .data-00000-of-00001} 239 # 240 # Before .save() finishes, they will be (hopefully, atomically) renamed to 241 # 242 # <train dir>/ 243 # myckpt{.index, .data-?????-of-?????} 244 # 245 # Filesystems with eventual consistency (such as S3), don't need a 246 # temporary location. Using a temporary directory in those cases might 247 # cause situations where files are not available during copy. 
248 # 249 # Users only need to interact with the user-specified prefix, which is 250 # "<train dir>/myckpt" in this case. Save() and Restore() work with the 251 # prefix directly, instead of any physical pathname. (On failure and 252 # subsequent restore, an outdated and orphaned temporary directory can be 253 # safely removed.) 254 with ops.device("CPU"): 255 _SHARDED_SUFFIX = array_ops.where( 256 string_ops.regex_full_match(checkpoint_prefix, "^s3://.*"), 257 constant_op.constant(".part"), 258 constant_op.constant(os.path.normpath("_temp/part"))) 259 tmp_checkpoint_prefix = string_ops.string_join( 260 [checkpoint_prefix, _SHARDED_SUFFIX]) 261 262 num_shards = len(per_device) 263 sharded_saves = [] 264 sharded_prefixes = [] 265 num_shards_tensor = constant_op.constant(num_shards, name="num_shards") 266 last_device = None 267 for shard, (device, saveables) in enumerate(per_device): 268 last_device = device 269 with ops.device(saveable_object_util.set_cpu0(device)): 270 sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard, 271 num_shards_tensor) 272 sharded_prefixes.append(sharded_filename) 273 sharded_saves.append(self._AddSaveOps(sharded_filename, saveables)) 274 275 with ops.control_dependencies([x.op for x in sharded_saves]): 276 # Co-locates the merge step with the last device. 277 with ops.device(saveable_object_util.set_cpu0(last_device)): 278 # V2 format write path consists of a metadata merge step. Once merged, 279 # attempts to delete the temporary directory, "<user-fed prefix>_temp". 280 merge_step = gen_io_ops.merge_v2_checkpoints( 281 sharded_prefixes, checkpoint_prefix, delete_old_dirs=True) 282 with ops.control_dependencies([merge_step]): 283 # Returns the prefix "<user-fed prefix>" only. DOES NOT include the 284 # sharded spec suffix. 285 return array_ops.identity(checkpoint_prefix) 286 287 def _AddShardedSaveOps(self, filename_tensor, per_device): 288 """Add ops to save the params per shard. 
289 290 Args: 291 filename_tensor: a scalar String Tensor. 292 per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as 293 returned by _GroupByDevices(). 294 295 Returns: 296 An op to save the variables. 297 """ 298 if self._write_version == saver_pb2.SaverDef.V2: 299 return self._AddShardedSaveOpsForV2(filename_tensor, per_device) 300 301 num_shards = len(per_device) 302 sharded_saves = [] 303 num_shards_tensor = constant_op.constant(num_shards, name="num_shards") 304 for shard, (device, saveables) in enumerate(per_device): 305 with ops.device(device): 306 sharded_filename = self.sharded_filename(filename_tensor, shard, 307 num_shards_tensor) 308 sharded_saves.append(self._AddSaveOps(sharded_filename, saveables)) 309 # Return the sharded name for the save path. 310 with ops.control_dependencies([x.op for x in sharded_saves]): 311 return gen_io_ops.sharded_filespec(filename_tensor, num_shards_tensor) 312 313 def _AddRestoreOps(self, 314 filename_tensor, 315 saveables, 316 restore_sequentially, 317 reshape, 318 preferred_shard=-1, 319 name="restore_all"): 320 """Add operations to restore saveables. 321 322 Args: 323 filename_tensor: Tensor for the path of the file to load. 324 saveables: A list of SaveableObject objects. 325 restore_sequentially: True if we want to restore variables sequentially 326 within a shard. 327 reshape: True if we want to reshape loaded tensors to the shape of the 328 corresponding variable. 329 preferred_shard: Shard to open first when loading a sharded file. 330 name: Name for the returned op. 331 332 Returns: 333 An Operation that restores the variables. 334 """ 335 all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard, 336 restore_sequentially) 337 338 assign_ops = [] 339 idx = 0 340 # Load and optionally reshape on the CPU, as string tensors are not 341 # available on the GPU. 342 # TODO(touts): Re-enable restore on GPU when we can support annotating 343 # string tensors as "HostMemory" inputs. 
344 for saveable in saveables: 345 shapes = None 346 if reshape: 347 # Compute the shapes, let the restore op decide if and how to do 348 # the reshape. 349 shapes = [] 350 for spec in saveable.specs: 351 v = spec.tensor 352 shape = v.get_shape() 353 if not shape.is_fully_defined(): 354 shape = array_ops.shape(v) 355 shapes.append(shape) 356 saveable_tensors = all_tensors[idx:idx + len(saveable.specs)] 357 idx += len(saveable.specs) 358 assign_ops.append(saveable.restore(saveable_tensors, shapes)) 359 360 # Create a Noop that has control dependencies from all the updates. 361 return control_flow_ops.group(*assign_ops, name=name) 362 363 def _AddShardedRestoreOps(self, filename_tensor, per_device, 364 restore_sequentially, reshape): 365 """Add Ops to restore variables from multiple devices. 366 367 Args: 368 filename_tensor: Tensor for the path of the file to load. 369 per_device: A list of (device, SaveableObject) pairs, as returned by 370 _GroupByDevices(). 371 restore_sequentially: True if we want to restore variables sequentially 372 within a shard. 373 reshape: True if we want to reshape loaded tensors to the shape of the 374 corresponding variable. 375 376 Returns: 377 An Operation that restores the variables. 378 """ 379 sharded_restores = [] 380 for shard, (device, saveables) in enumerate(per_device): 381 with ops.device(device): 382 sharded_restores.append( 383 self._AddRestoreOps( 384 filename_tensor, 385 saveables, 386 restore_sequentially, 387 reshape, 388 preferred_shard=shard, 389 name="restore_shard")) 390 return control_flow_ops.group(*sharded_restores, name="restore_all") 391 392 def _GroupByDevices(self, saveables): 393 """Group Variable tensor slices per device. 394 395 TODO(touts): Make sure that all the devices found are on different 396 job/replica/task/cpu|gpu. It would be bad if 2 were on the same device. 397 It can happen if the devices are unspecified. 398 399 Args: 400 saveables: A list of BaseSaverBuilder.SaveableObject objects. 
401 402 Returns: 403 A list of tuples: (device_name, BaseSaverBuilder.SaveableObject) tuples. 404 The list is sorted by ascending device_name. 405 406 Raises: 407 ValueError: If the tensors of a saveable are on different devices. 408 """ 409 per_device = collections.defaultdict(lambda: []) 410 for saveable in saveables: 411 canonical_device = set( 412 pydev.canonical_name(spec.device) for spec in saveable.specs) 413 if len(canonical_device) != 1: 414 raise ValueError("All tensors of a saveable object must be " 415 "on the same device: %s" % saveable.name) 416 per_device[canonical_device.pop()].append(saveable) 417 return sorted(per_device.items(), key=lambda t: t[0]) 418 419 def build(self, 420 names_to_saveables, 421 reshape=False, 422 sharded=False, 423 max_to_keep=5, 424 keep_checkpoint_every_n_hours=10000.0, 425 name=None, 426 restore_sequentially=False, 427 filename="model"): 428 """Builds save/restore graph nodes or runs save/restore in eager mode. 429 430 Args: 431 names_to_saveables: A dictionary mapping name to a Variable or 432 SaveableObject. Each name will be associated with the corresponding 433 variable in the checkpoint. 434 reshape: If True, allow restoring parameters from a checkpoint that where 435 the parameters have a different shape. This is only needed when you try 436 to restore from a Dist-Belief checkpoint, and only some times. 437 sharded: If True, shard the checkpoints, one per device that has Variable 438 nodes. 439 max_to_keep: Maximum number of checkpoints to keep. As new checkpoints 440 are created, old ones are deleted. If None or 0, no checkpoints are 441 deleted from the filesystem but only the last one is kept in the 442 `checkpoint` file. Presently the number is only roughly enforced. For 443 example in case of restarts more than max_to_keep checkpoints may be 444 kept. 445 keep_checkpoint_every_n_hours: How often checkpoints should be kept. 446 Defaults to 10,000 hours. 447 name: String. 
Optional name to use as a prefix when adding operations. 448 restore_sequentially: A Bool, which if true, causes restore of different 449 variables to happen sequentially within each device. 450 filename: If known at graph construction time, filename used for variable 451 loading/saving. If None, then the default name "model" will be used. 452 453 Returns: 454 A SaverDef proto. 455 456 Raises: 457 TypeError: If 'names_to_saveables' is not a dictionary mapping string 458 keys to variable Tensors. 459 ValueError: If any of the keys or values in 'names_to_saveables' is not 460 unique. 461 """ 462 return self._build_internal( 463 names_to_saveables=names_to_saveables, 464 reshape=reshape, 465 sharded=sharded, 466 max_to_keep=max_to_keep, 467 keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, 468 name=name, 469 restore_sequentially=restore_sequentially, 470 filename=filename) 471 472 def _build_internal(self, 473 names_to_saveables, 474 reshape=False, 475 sharded=False, 476 max_to_keep=5, 477 keep_checkpoint_every_n_hours=10000.0, 478 name=None, 479 restore_sequentially=False, 480 filename="model", 481 build_save=True, 482 build_restore=True): 483 """build() with option to only perform save and restore.""" 484 if not context.executing_eagerly() and (not build_save or 485 not build_restore): 486 raise ValueError("save and restore operations need to be built together " 487 " when eager execution is not enabled.") 488 489 saveables = saveable_object_util.validate_and_slice_inputs( 490 names_to_saveables) 491 if max_to_keep is None: 492 max_to_keep = 0 493 494 with ops.name_scope(name, "save", 495 [saveable.op for saveable in saveables]) as name: 496 # Add a placeholder string tensor for the filename. 497 filename_tensor = array_ops.placeholder_with_default( 498 filename or "model", shape=(), name="filename") 499 # Keep the name "Const" for backwards compatibility. 
500 filename_tensor = array_ops.placeholder_with_default( 501 filename_tensor, shape=(), name="Const") 502 503 # Add the save ops. 504 if sharded: 505 per_device = self._GroupByDevices(saveables) 506 if build_save: 507 save_tensor = self._AddShardedSaveOps(filename_tensor, per_device) 508 if build_restore: 509 restore_op = self._AddShardedRestoreOps(filename_tensor, per_device, 510 restore_sequentially, reshape) 511 else: 512 if build_save: 513 save_tensor = self._AddSaveOps(filename_tensor, saveables) 514 if build_restore: 515 restore_op = self._AddRestoreOps(filename_tensor, saveables, 516 restore_sequentially, reshape) 517 518 # In the following use case, it's possible to have restore_ops be called 519 # something else: 520 # - Build inference graph and export a meta_graph. 521 # - Import the inference meta_graph 522 # - Extend the inference graph to a train graph. 523 # - Export a new meta_graph. 524 # Now the second restore_op will be called "restore_all_1". 525 # As such, comment out the assert for now until we know whether supporting 526 # such usage model makes sense. 527 # 528 # assert restore_op.name.endswith("restore_all"), restore_op.name 529 if context.executing_eagerly(): 530 # Store the tensor values to the tensor_names. 531 save_tensor_name = save_tensor.numpy() if build_save else "" 532 return saver_pb2.SaverDef( 533 filename_tensor_name=filename_tensor.numpy(), 534 save_tensor_name=save_tensor_name, 535 restore_op_name="", 536 max_to_keep=max_to_keep, 537 sharded=sharded, 538 keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, 539 version=self._write_version) 540 else: 541 graph = ops.get_default_graph() 542 # Do some sanity checking on collections containing 543 # PartitionedVariables. If a saved collection has a PartitionedVariable, 544 # the GraphDef needs to include concat ops to get the value (or there'll 545 # be a lookup error on load). 
546 check_collection_list = graph.get_all_collection_keys() 547 for collection_type in check_collection_list: 548 for element in graph.get_collection(collection_type): 549 if isinstance(element, variables.PartitionedVariable): 550 try: 551 graph.get_operation_by_name(element.name) 552 except KeyError: 553 # Create a concat op for this PartitionedVariable. The user may 554 # not need it, but we'll try looking it up on MetaGraph restore 555 # since it's in a collection. 556 element.as_tensor() 557 return saver_pb2.SaverDef( 558 filename_tensor_name=filename_tensor.name, 559 save_tensor_name=save_tensor.name, 560 restore_op_name=restore_op.name, 561 max_to_keep=max_to_keep, 562 sharded=sharded, 563 keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, 564 version=self._write_version) 565 566 567class BulkSaverBuilder(BaseSaverBuilder): 568 """SaverBuilder with support for bulk restoring multiple saveables.""" 569 570 def bulk_restore(self, filename_tensor, saveables, preferred_shard, 571 restore_sequentially): 572 573 # Ignored: bulk restore is internally sequential. 574 del restore_sequentially 575 restore_specs = [] 576 for saveable in saveables: 577 for spec in saveable.specs: 578 restore_specs.append((spec.name, spec.slice_spec, spec.dtype)) 579 580 names, slices, dtypes = zip(*restore_specs) 581 # Load all tensors onto CPU 0 for compatibility with existing code. 582 with ops.device("cpu:0"): 583 return io_ops.restore_v2(filename_tensor, names, slices, dtypes) 584 585 586def _get_saver_or_default(): 587 """Returns the saver from SAVERS collection, or creates a default one. 588 589 This method is used by other members of the training module, such as 590 `Scaffold`, or `CheckpointSaverHook`. 591 592 Returns: 593 `Saver`. 594 595 Raises: 596 RuntimeError: If the SAVERS collection already has more than one items. 
597 """ 598 collection_key = ops.GraphKeys.SAVERS 599 savers = ops.get_collection(collection_key) 600 if savers: 601 if len(savers) > 1: 602 raise RuntimeError( 603 "More than one item in collection {}. " 604 "Please indicate which one to use by passing it to the constructor." 605 .format(collection_key)) 606 return savers[0] 607 saver = Saver(sharded=True, allow_empty=True) 608 if saver is not None: 609 ops.add_to_collection(collection_key, saver) 610 return saver 611 612 613@tf_export(v1=["train.Saver"]) 614class Saver(object): 615 """Saves and restores variables. 616 617 See [Variables](https://tensorflow.org/guide/variables) 618 for an overview of variables, saving and restoring. 619 620 The `Saver` class adds ops to save and restore variables to and from 621 *checkpoints*. It also provides convenience methods to run these ops. 622 623 Checkpoints are binary files in a proprietary format which map variable names 624 to tensor values. The best way to examine the contents of a checkpoint is to 625 load it using a `Saver`. 626 627 Savers can automatically number checkpoint filenames with a provided counter. 628 This lets you keep multiple checkpoints at different steps while training a 629 model. For example you can number the checkpoint filenames with the training 630 step number. To avoid filling up disks, savers manage checkpoint files 631 automatically. For example, they can keep only the N most recent files, or 632 one checkpoint for every N hours of training. 633 634 You number checkpoint filenames by passing a value to the optional 635 `global_step` argument to `save()`: 636 637 ```python 638 saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0' 639 ... 
640 saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000' 641 ``` 642 643 Additionally, optional arguments to the `Saver()` constructor let you control 644 the proliferation of checkpoint files on disk: 645 646 * `max_to_keep` indicates the maximum number of recent checkpoint files to 647 keep. As new files are created, older files are deleted. If None or 0, 648 no checkpoints are deleted from the filesystem but only the last one is 649 kept in the `checkpoint` file. Defaults to 5 (that is, the 5 most recent 650 checkpoint files are kept.) 651 652 * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent 653 `max_to_keep` checkpoint files, you might want to keep one checkpoint file 654 for every N hours of training. This can be useful if you want to later 655 analyze how a model progressed during a long training session. For 656 example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep 657 one checkpoint file for every 2 hours of training. The default value of 658 10,000 hours effectively disables the feature. 659 660 Note that you still have to call the `save()` method to save the model. 661 Passing these arguments to the constructor will not save variables 662 automatically for you. 663 664 A training program that saves regularly looks like: 665 666 ```python 667 ... 668 # Create a saver. 669 saver = tf.compat.v1.train.Saver(...variables...) 670 # Launch the graph and train, saving the model every 1,000 steps. 671 sess = tf.compat.v1.Session() 672 for step in xrange(1000000): 673 sess.run(..training_op..) 674 if step % 1000 == 0: 675 # Append the step number to the checkpoint name: 676 saver.save(sess, 'my-model', global_step=step) 677 ``` 678 679 In addition to checkpoint files, savers keep a protocol buffer on disk with 680 the list of recent checkpoints. 
This is used to manage numbered checkpoint 681 files and by `latest_checkpoint()`, which makes it easy to discover the path 682 to the most recent checkpoint. That protocol buffer is stored in a file named 683 'checkpoint' next to the checkpoint files. 684 685 If you create several savers, you can specify a different filename for the 686 protocol buffer file in the call to `save()`. 687 """ 688 689 def __init__(self, 690 var_list=None, 691 reshape=False, 692 sharded=False, 693 max_to_keep=5, 694 keep_checkpoint_every_n_hours=10000.0, 695 name=None, 696 restore_sequentially=False, 697 saver_def=None, 698 builder=None, 699 defer_build=False, 700 allow_empty=False, 701 write_version=saver_pb2.SaverDef.V2, 702 pad_step_number=False, 703 save_relative_paths=False, 704 filename=None): 705 """Creates a `Saver`. 706 707 The constructor adds ops to save and restore variables. 708 709 `var_list` specifies the variables that will be saved and restored. It can 710 be passed as a `dict` or a list: 711 712 * A `dict` of names to variables: The keys are the names that will be 713 used to save or restore the variables in the checkpoint files. 714 * A list of variables: The variables will be keyed with their op name in 715 the checkpoint files. 716 717 For example: 718 719 ```python 720 v1 = tf.Variable(..., name='v1') 721 v2 = tf.Variable(..., name='v2') 722 723 # Pass the variables as a dict: 724 saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2}) 725 726 # Or pass them as a list. 727 saver = tf.compat.v1.train.Saver([v1, v2]) 728 # Passing a list is equivalent to passing a dict with the variable op names 729 # as keys: 730 saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]}) 731 ``` 732 733 Note: the newer `AutoTrackable` API is not supported by `Saver`. In this 734 case, the `tf.train.Checkpoint` class should be used. 
735 736 The optional `reshape` argument, if `True`, allows restoring a variable from 737 a save file where the variable had a different shape, but the same number 738 of elements and type. This is useful if you have reshaped a variable and 739 want to reload it from an older checkpoint. 740 741 The optional `sharded` argument, if `True`, instructs the saver to shard 742 checkpoints per device. 743 744 Args: 745 var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping 746 names to `SaveableObject`s. If `None`, defaults to the list of all 747 saveable objects. 748 reshape: If `True`, allows restoring parameters from a checkpoint where 749 the variables have a different shape. 750 sharded: If `True`, shard the checkpoints, one per device. 751 max_to_keep: Maximum number of recent checkpoints to keep. Defaults to 5. 752 keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to 753 10,000 hours. 754 name: String. Optional name to use as a prefix when adding operations. 755 restore_sequentially: A `Bool`, which if true, causes restore of different 756 variables to happen sequentially within each device. This can lower 757 memory usage when restoring very large models. 758 saver_def: Optional `SaverDef` proto to use instead of running the 759 builder. This is only useful for specialty code that wants to recreate a 760 `Saver` object for a previously built `Graph` that had a `Saver`. The 761 `saver_def` proto should be the one returned by the `as_saver_def()` 762 call of the `Saver` that was created for that `Graph`. 763 builder: Optional `SaverBuilder` to use if a `saver_def` was not provided. 764 Defaults to `BulkSaverBuilder()`. 765 defer_build: If `True`, defer adding the save and restore ops to the 766 `build()` call. In that case `build()` should be called before 767 finalizing the graph or using the saver. 768 allow_empty: If `False` (default) raise an error if there are no variables 769 in the graph. 
Otherwise, construct the saver anyway and make it a no-op. 770 write_version: controls what format to use when saving checkpoints. It 771 also affects certain filepath matching logic. The V2 format is the 772 recommended choice: it is much more optimized than V1 in terms of memory 773 required and latency incurred during restore. Regardless of this 774 flag, the Saver is able to restore from both V2 and V1 checkpoints. 775 pad_step_number: if True, pads the global step number in the checkpoint 776 filepaths to some fixed width (8 by default). This is turned off by 777 default. 778 save_relative_paths: If `True`, will write relative paths to the 779 checkpoint state file. This is needed if the user wants to copy the 780 checkpoint directory and reload from the copied directory. 781 filename: If known at graph construction time, filename used for variable 782 loading/saving. 783 784 Raises: 785 TypeError: If `var_list` is invalid. 786 ValueError: If any of the keys or values in `var_list` are not unique. 787 RuntimeError: If eager execution is enabled and`var_list` does not specify 788 a list of variables to save. 789 790 @compatibility(eager) 791 When eager execution is enabled, `var_list` must specify a `list` or `dict` 792 of variables to save. Otherwise, a `RuntimeError` will be raised. 793 794 Although Saver works in some cases when executing eagerly, it is 795 fragile. Please switch to `tf.train.Checkpoint` or 796 `tf.keras.Model.save_weights`, which perform a more robust object-based 797 saving. These APIs will load checkpoints written by `Saver`. 798 @end_compatibility 799 """ 800 if defer_build and var_list: 801 raise ValueError( 802 "If `var_list` is provided then build cannot be deferred. " 803 "Either set defer_build=False or var_list=None.") 804 if context.executing_eagerly(): 805 logging.warning( 806 "Saver is deprecated, please switch to tf.train.Checkpoint or " 807 "tf.keras.Model.save_weights for training checkpoints. 
When " 808 "executing eagerly variables do not necessarily have unique names, " 809 "and so the variable.name-based lookups Saver performs are " 810 "error-prone.") 811 if var_list is None: 812 raise RuntimeError( 813 "When eager execution is enabled, `var_list` must specify a list " 814 "or dict of variables to save") 815 self._var_list = var_list 816 self._reshape = reshape 817 self._sharded = sharded 818 self._max_to_keep = max_to_keep 819 self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours 820 self._name = name 821 self._restore_sequentially = restore_sequentially 822 self.saver_def = saver_def 823 self._builder = builder 824 self._is_built = False 825 self._allow_empty = allow_empty 826 self._is_empty = None 827 self._write_version = write_version 828 self._pad_step_number = pad_step_number 829 self._filename = filename 830 self._last_checkpoints = [] 831 self._checkpoints_to_be_deleted = [] 832 if context.executing_eagerly(): 833 self._next_checkpoint_time = ( 834 time.time() + self._keep_checkpoint_every_n_hours * 3600) 835 elif not defer_build: 836 self.build() 837 if self.saver_def: 838 self._check_saver_def() 839 self._write_version = self.saver_def.version 840 self._save_relative_paths = save_relative_paths 841 # For compatibility with object-based checkpoints, we may build a second 842 # Saver to read the renamed keys. 
    # Cache for a secondary Saver built lazily by `restore()` when a
    # name-based restore hits an object-based checkpoint (see `restore`).
    self._object_restore_saver = None

  def build(self):
    """Builds the save/restore ops; not supported when executing eagerly."""
    if context.executing_eagerly():
      raise RuntimeError("Use save/restore instead of build in eager mode.")
    self._build(self._filename, build_save=True, build_restore=True)

  def _build_eager(self, checkpoint_path, build_save, build_restore):
    # Eager-mode entry point: rebuilds on every call since there is no
    # persistent graph to cache the ops in.
    self._build(
        checkpoint_path, build_save=build_save, build_restore=build_restore)

  def _build(self, checkpoint_path, build_save, build_restore):
    """Builds saver_def."""
    if not context.executing_eagerly():
      # Graph mode: building is idempotent.
      if self._is_built:
        return
      self._is_built = True

    if not self.saver_def or context.executing_eagerly():
      if self._builder is None:
        self._builder = BulkSaverBuilder(self._write_version)

      if self._var_list is None:
        # pylint: disable=protected-access
        self._var_list = variables._all_saveable_objects()
      if not self._var_list:
        if self._allow_empty:
          self._is_empty = True
          return
        else:
          raise ValueError("No variables to save")
      self._is_empty = False

      self.saver_def = self._builder._build_internal(  # pylint: disable=protected-access
          self._var_list,
          reshape=self._reshape,
          sharded=self._sharded,
          max_to_keep=self._max_to_keep,
          keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
          name=self._name,
          restore_sequentially=self._restore_sequentially,
          filename=checkpoint_path,
          build_save=build_save,
          build_restore=build_restore)
    elif self.saver_def and self._name:
      # Since self._name is used as a name_scope by builder(), we are
      # overloading the use of this field to represent the "import_scope" as
      # well.
      self.saver_def.filename_tensor_name = ops.prepend_name_scope(
          self.saver_def.filename_tensor_name, self._name)
      self.saver_def.save_tensor_name = ops.prepend_name_scope(
          self.saver_def.save_tensor_name, self._name)
      self.saver_def.restore_op_name = ops.prepend_name_scope(
          self.saver_def.restore_op_name, self._name)

    self._check_saver_def()
    if not context.executing_eagerly():
      # Updates next checkpoint time.
      # Set in __init__ when executing eagerly.
      self._next_checkpoint_time = (
          time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)

  def _check_saver_def(self):
    """Validates that `self.saver_def` is a well-formed `SaverDef`.

    Raises:
      ValueError: If `saver_def` is not a `SaverDef` proto, or (in graph
        mode) is missing `save_tensor_name` or `restore_op_name`.
    """
    if not isinstance(self.saver_def, saver_pb2.SaverDef):
      raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
                       self.saver_def)
    if not context.executing_eagerly():
      if not self.saver_def.save_tensor_name:
        raise ValueError("saver_def must specify the save_tensor_name: %s" %
                         str(self.saver_def))
      if not self.saver_def.restore_op_name:
        raise ValueError("saver_def must specify the restore_op_name: %s" %
                         str(self.saver_def))

  def _CheckpointFilename(self, p):
    """Returns the checkpoint filename given a `(filename, time)` pair.

    Args:
      p: (filename, time) pair.

    Returns:
      Checkpoint file name.
    """
    name, _ = p
    return name

  def _RecordLastCheckpoint(self, latest_save_path):
    """Manages the list of the latest checkpoints."""
    if not self.saver_def.max_to_keep:
      return
    # Remove first from list if the same name was used before.
    for p in self._last_checkpoints:
      if latest_save_path == self._CheckpointFilename(p):
        self._last_checkpoints.remove(p)
    # Append new path to list
    self._last_checkpoints.append((latest_save_path, time.time()))

    # If more than max_to_keep, remove oldest.
    if len(self._last_checkpoints) > self.saver_def.max_to_keep:
      self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0))

  def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"):
    """Deletes old checkpoints if necessary.

    `self._checkpoints_to_be_deleted` is going to contain checkpoints that are
    over `max_to_keep`. They are going to be deleted. If
    `keep_checkpoint_every_n_hours` was specified, keep an additional
    checkpoint every `N` hours. For example, if `N` is 0.5, an additional
    checkpoint is kept for every 0.5 hours of training; if `N` is 10, an
    additional checkpoint is kept for every 10 hours of training.

    Args:
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
    """
    if self._checkpoints_to_be_deleted:
      p = self._checkpoints_to_be_deleted.pop(0)
      # Do not delete the file if we keep_checkpoint_every_n_hours is set and we
      # have reached N hours of training.
      should_keep = p[1] > self._next_checkpoint_time
      if should_keep:
        self._next_checkpoint_time += (
            self.saver_def.keep_checkpoint_every_n_hours * 3600)
        return

      # Otherwise delete the files.
      try:
        checkpoint_management.remove_checkpoint(
            self._CheckpointFilename(p), self.saver_def.version,
            meta_graph_suffix)
      except Exception as e:  # pylint: disable=broad-except
        # Deletion is best-effort: a concurrent deleter is not an error.
        logging.warning("Ignoring: %s", str(e))

  def as_saver_def(self):
    """Generates a `SaverDef` representation of this saver.

    Returns:
      A `SaverDef` proto.
    """
    return self.saver_def

  def to_proto(self, export_scope=None):
    """Converts this `Saver` to a `SaverDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `SaverDef` protocol buffer.
    """
    if export_scope is None:
      return self.saver_def

    if not (self.saver_def.filename_tensor_name.startswith(export_scope) and
            self.saver_def.save_tensor_name.startswith(export_scope) and
            self.saver_def.restore_op_name.startswith(export_scope)):
      # The saver does not belong to the requested scope: nothing to export.
      return None

    saver_def = saver_pb2.SaverDef()
    saver_def.CopyFrom(self.saver_def)
    saver_def.filename_tensor_name = ops.strip_name_scope(
        saver_def.filename_tensor_name, export_scope)
    saver_def.save_tensor_name = ops.strip_name_scope(
        saver_def.save_tensor_name, export_scope)
    saver_def.restore_op_name = ops.strip_name_scope(saver_def.restore_op_name,
                                                     export_scope)
    return saver_def

  @staticmethod
  def from_proto(saver_def, import_scope=None):
    """Returns a `Saver` object created from `saver_def`.

    Args:
      saver_def: a `SaverDef` protocol buffer.
      import_scope: Optional `string`. Name scope to use.

    Returns:
      A `Saver` built from saver_def.
    """
    return Saver(saver_def=saver_def, name=import_scope)

  @property
  def last_checkpoints(self):
    """List of not-yet-deleted checkpoint filenames.

    You can pass any of the returned values to `restore()`.

    Returns:
      A list of checkpoint filenames, sorted from oldest to newest.
    """
    return list(self._CheckpointFilename(p) for p in self._last_checkpoints)

  def set_last_checkpoints(self, last_checkpoints):
    """DEPRECATED: Use set_last_checkpoints_with_time.

    Sets the list of old checkpoint filenames.

    Args:
      last_checkpoints: A list of checkpoint filenames.

    Raises:
      AssertionError: If last_checkpoints is not a list.
    """
    assert isinstance(last_checkpoints, list)
    # We use a timestamp of +inf so that this checkpoint will never be
    # deleted.  This is both safe and backwards compatible to a previous
    # version of the code which used s[1] as the "timestamp".
    self._last_checkpoints = [(s, np.inf) for s in last_checkpoints]

  def set_last_checkpoints_with_time(self, last_checkpoints_with_time):
    """Sets the list of old checkpoint filenames and timestamps.

    Args:
      last_checkpoints_with_time: A list of tuples of checkpoint filenames and
        timestamps.

    Raises:
      AssertionError: If last_checkpoints_with_time is not a list.
    """
    assert isinstance(last_checkpoints_with_time, list)
    self._last_checkpoints = last_checkpoints_with_time

  def recover_last_checkpoints(self, checkpoint_paths):
    """Recovers the internal saver state after a crash.

    This method is useful for recovering the "self._last_checkpoints" state.

    Globs for the checkpoints pointed to by `checkpoint_paths`. If the files
    exist, use their mtime as the checkpoint timestamp.

    Args:
      checkpoint_paths: a list of checkpoint paths.
    """
    checkpoints_with_mtimes = []
    for checkpoint_path in checkpoint_paths:
      try:
        mtime = checkpoint_management.get_checkpoint_mtimes([checkpoint_path])
      except errors.NotFoundError:
        # It's fine if some other thread/process is deleting some older
        # checkpoint concurrently.
        continue
      if mtime:
        checkpoints_with_mtimes.append((checkpoint_path, mtime[0]))
    self.set_last_checkpoints_with_time(checkpoints_with_mtimes)

  def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True,
           strip_default_attrs=False,
           save_debug_info=False):
    # pylint: disable=line-too-long
    """Saves variables.

    This method runs the ops added by the constructor for saving variables.
    It requires a session in which the graph was launched.
    The variables to save must also have been initialized.

    The method returns the path prefix of the newly created checkpoint files.
    This string can be passed directly to a call to `restore()`.

    Args:
      sess: A Session to use to save the variables.
      save_path: String. Prefix of filenames created for the checkpoint.
      global_step: If provided the global step number is appended to `save_path`
        to create the checkpoint filenames. The optional argument can be a
        `Tensor`, a `Tensor` name or an integer.
      latest_filename: Optional name for the protocol buffer file that will
        contain the list of most recent checkpoints. That file, kept in the
        same directory as the checkpoint files, is automatically managed by the
        saver to keep track of recent checkpoints. Defaults to 'checkpoint'.
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
      write_meta_graph: `Boolean` indicating whether or not to write the meta
        graph file.
      write_state: `Boolean` indicating whether or not to write the
        `CheckpointStateProto`.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued
        Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
        which in the same directory of save_path and with `_debug` added before
        the file extension. This is only enabled when `write_meta_graph` is
        `True`.

    Returns:
      A string: path prefix used for the checkpoint files. If the saver is
        sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
        is the number of shards created.
      If the saver is empty, returns None.

    Raises:
      TypeError: If `sess` is not a `Session`.
      ValueError: If `latest_filename` contains path components, or if it
        collides with `save_path`.
      RuntimeError: If save and restore ops weren't built.
    """
    # pylint: enable=line-too-long
    if not self._is_built and not context.executing_eagerly():
      raise RuntimeError(
          "`build()` should be called before save if defer_build==True")
    if latest_filename is None:
      latest_filename = "checkpoint"
    if self._write_version != saver_pb2.SaverDef.V2:
      logging.warning("*******************************************************")
      logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
      logging.warning("Consider switching to the more efficient V2 format:")
      logging.warning("   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
      logging.warning("now on by default.")
      logging.warning("*******************************************************")

    if os.path.split(latest_filename)[0]:
      raise ValueError("'latest_filename' must not contain path components")

    save_path = compat.as_str(save_path)
    if global_step is not None:
      if not isinstance(global_step, compat.integral_types):
        global_step = training_util.global_step(sess, global_step)
      checkpoint_file = "%s-%d" % (save_path, global_step)
      if self._pad_step_number:
        # Zero-pads the step numbers, so that they are sorted when listed.
        checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
    else:
      checkpoint_file = save_path
      if os.path.basename(save_path) == latest_filename and not self._sharded:
        # Guard against collision between data file and checkpoint state file.
        raise ValueError(
            "'latest_filename' collides with 'save_path': '%s' and '%s'" %
            (latest_filename, save_path))

    if (not context.executing_eagerly() and
        not isinstance(sess, session.SessionInterface)):
      raise TypeError("'sess' must be a Session; %s" % sess)

    save_path_parent = os.path.dirname(save_path)
    if not self._is_empty:
      try:
        if context.executing_eagerly():
          self._build_eager(
              checkpoint_file, build_save=True, build_restore=False)
          model_checkpoint_path = self.saver_def.save_tensor_name
        else:
          model_checkpoint_path = sess.run(
              self.saver_def.save_tensor_name,
              {self.saver_def.filename_tensor_name: checkpoint_file})

        model_checkpoint_path = compat.as_str(model_checkpoint_path)
        if write_state:
          self._RecordLastCheckpoint(model_checkpoint_path)
          checkpoint_management.update_checkpoint_state_internal(
              save_dir=save_path_parent,
              model_checkpoint_path=model_checkpoint_path,
              all_model_checkpoint_paths=self.last_checkpoints,
              latest_filename=latest_filename,
              save_relative_paths=self._save_relative_paths)
          self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
      except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
        if not gfile.IsDirectory(save_path_parent):
          # A missing parent directory is the usual root cause; surface that
          # instead of the lower-level op failure.
          exc = ValueError(
              "Parent directory of {} doesn't exist, can't save.".format(
                  save_path))
        raise exc

    if write_meta_graph:
      meta_graph_filename = checkpoint_management.meta_graph_filename(
          checkpoint_file, meta_graph_suffix=meta_graph_suffix)
      if not context.executing_eagerly():
        with sess.graph.as_default():
          self.export_meta_graph(
              meta_graph_filename,
              strip_default_attrs=strip_default_attrs,
              save_debug_info=save_debug_info)

    if self._is_empty:
      return None
    else:
      return model_checkpoint_path

  def export_meta_graph(self,
                        filename=None,
                        collection_list=None,
                        as_text=False,
                        export_scope=None,
                        clear_devices=False,
                        clear_extraneous_savers=False,
                        strip_default_attrs=False,
                        save_debug_info=False):
    # pylint: disable=line-too-long
    """Writes `MetaGraphDef` to save_path/filename.

    Args:
      filename: Optional meta_graph filename including the path.
      collection_list: List of string keys to collect.
      as_text: If `True`, writes the meta_graph as an ASCII proto.
      export_scope: Optional `string`. Name scope to remove.
      clear_devices: Whether or not to clear the device field for an
        `Operation` or `Tensor` during export.
      clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that are not associated
        with this Saver.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will
        be removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued
        Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
        which in the same directory of filename and with `_debug` added before
        the file extension.

    Returns:
      A `MetaGraphDef` proto.
    """
    # pylint: enable=line-too-long
    return export_meta_graph(
        filename=filename,
        graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
        saver_def=self.saver_def,
        collection_list=collection_list,
        as_text=as_text,
        export_scope=export_scope,
        clear_devices=clear_devices,
        clear_extraneous_savers=clear_extraneous_savers,
        strip_default_attrs=strip_default_attrs,
        save_debug_info=save_debug_info)

  def restore(self, sess, save_path):
    """Restores previously saved variables.

    This method runs the ops added by the constructor for restoring variables.
    It requires a session in which the graph was launched. The variables to
    restore do not have to have been initialized, as restoring is itself a way
    to initialize variables.

    The `save_path` argument is typically a value previously returned from a
    `save()` call, or a call to `latest_checkpoint()`.

    Args:
      sess: A `Session` to use to restore the parameters. None in eager mode.
      save_path: Path where parameters were previously saved.

    Raises:
      ValueError: If save_path is None or not a valid checkpoint.
    """
    if self._is_empty:
      return
    if save_path is None:
      raise ValueError("Can't load save_path when it is None.")

    checkpoint_prefix = compat.as_text(save_path)
    if not checkpoint_management.checkpoint_exists_internal(checkpoint_prefix):
      raise ValueError("The passed save_path is not a valid checkpoint: " +
                       checkpoint_prefix)

    logging.info("Restoring parameters from %s", checkpoint_prefix)
    try:
      if context.executing_eagerly():
        self._build_eager(save_path, build_save=False, build_restore=True)
      else:
        sess.run(self.saver_def.restore_op_name,
                 {self.saver_def.filename_tensor_name: save_path})
    except errors.NotFoundError as err:
      # There are three common conditions that might cause this error:
      # 0. The file is missing. We ignore here, as this is checked above.
      # 1. This is an object-based checkpoint trying name-based loading.
      # 2. The graph has been altered and a variable or other name is missing.

      # 1. The checkpoint would not be loaded successfully as is. Try to parse
      # it as an object-based checkpoint.
      try:
        names_to_keys = object_graph_key_mapping(save_path)
      except errors.NotFoundError:
        # 2. This is not an object-based checkpoint, which likely means there
        # is a graph mismatch.
        # Re-raise the original error with a helpful message (b/110263146).
        raise _wrap_restore_error_with_msg(
            err, "a Variable name or other graph key that is missing")

      # This is an object-based checkpoint. We'll print a warning and then do
      # the restore.
      logging.warning(
          "Restoring an object-based checkpoint using a name-based saver. This "
          "may be somewhat fragile, and will re-build the Saver. Instead, "
          "consider loading object-based checkpoints using "
          "tf.train.Checkpoint().")
      # The secondary saver is cached so repeated restores don't keep adding
      # ops to the graph.
      self._object_restore_saver = saver_from_object_based_checkpoint(
          checkpoint_path=save_path,
          var_list=self._var_list,
          builder=self._builder,
          names_to_keys=names_to_keys,
          cached_saver=self._object_restore_saver)
      self._object_restore_saver.restore(sess=sess, save_path=save_path)
    except errors.InvalidArgumentError as err:
      # There is a mismatch between the graph and the checkpoint being loaded.
      # We add a more reasonable error message here to help users (b/110263146)
      raise _wrap_restore_error_with_msg(
          err, "a mismatch between the current graph and the graph")

  @staticmethod
  def _add_collection_def(meta_graph_def, key, export_scope=None):
    """Adds a collection to MetaGraphDef protocol buffer.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
      export_scope: Optional `string`. Name scope to remove.
    """
    meta_graph.add_collection_def(
        meta_graph_def, key, export_scope=export_scope)


@tf_export(v1=["train.import_meta_graph"])
def import_meta_graph(meta_graph_or_file,
                      clear_devices=False,
                      import_scope=None,
                      **kwargs):
  """Recreates a Graph saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input.
  If the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates all the collections, and returns a saver
  constructed from the `saver_def` field.

  In combination with `export_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  ```Python
  ...
  # Create a saver.
  saver = tf.compat.v1.train.Saver(...variables...)
  # Remember the training_op we want to run by adding it to a collection.
  tf.compat.v1.add_to_collection('train_op', train_op)
  sess = tf.compat.v1.Session()
  for step in xrange(1000000):
      sess.run(train_op)
      if step % 1000 == 0:
          # Saves checkpoint, which by default also exports a meta_graph
          # named 'my-model-global_step.meta'.
          saver.save(sess, 'my-model', global_step=step)
  ```

  Later we can continue training from this saved `meta_graph` without building
  the model from scratch.

  ```Python
  with tf.Session() as sess:
    new_saver =
        tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
    new_saver.restore(sess, 'my-save-dir/my-model-10000')
    # tf.get_collection() returns a list. In this example we only want
    # the first one.
    train_op = tf.get_collection('train_op')[0]
    for step in xrange(1000000):
      sess.run(train_op)
  ```

  NOTE: Restarting training from saved `meta_graph` only works if the
  device assignments have not changed.

  Example:
  Variables, placeholders, and independent operations can also be stored, as
  shown in the following example.

  ```Python
  # Saving contents and operations.
  v1 = tf.placeholder(tf.float32, name="v1")
  v2 = tf.placeholder(tf.float32, name="v2")
  v3 = tf.math.multiply(v1, v2)
  vx = tf.Variable(10.0, name="vx")
  v4 = tf.add(v3, vx, name="v4")
  saver = tf.train.Saver([vx])
  sess = tf.Session()
  sess.run(tf.global_variables_initializer())
  sess.run(vx.assign(tf.add(vx, vx)))
  result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
  print(result)
  saver.save(sess, "./model_ex1")
  ```

  Later this model can be restored and contents loaded.

  ```Python
  # Restoring variables and running operations.
  saver = tf.train.import_meta_graph("./model_ex1.meta")
  sess = tf.Session()
  saver.restore(sess, "./model_ex1")
  result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
  print(result)
  ```

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Whether or not to clear the device field for an `Operation`
      or `Tensor` during import.
    import_scope: Optional `string`. Name scope to add. Only used when
      initializing from protocol buffer.
    **kwargs: Optional keyed arguments.

  Returns:
    A saver constructed from `saver_def` in `MetaGraphDef` or None.

    A None value is returned if no variables exist in the `MetaGraphDef`
    (i.e., there are no variables to restore).

  Raises:
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Exporting/importing meta graphs is not supported. No graph exists when eager
  execution is enabled.
  @end_compatibility
  """  # pylint: disable=g-doc-exception
  return _import_meta_graph_with_return_elements(meta_graph_or_file,
                                                 clear_devices, import_scope,
                                                 **kwargs)[0]


def _import_meta_graph_with_return_elements(meta_graph_or_file,
                                            clear_devices=False,
                                            import_scope=None,
                                            return_elements=None,
                                            **kwargs):
  """Import MetaGraph, and return both a saver and returned elements."""
  if context.executing_eagerly():
    raise RuntimeError("Exporting/importing meta graphs is not supported when "
                       "eager execution is enabled. No graph exists when eager "
                       "execution is enabled.")
  if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    # A filename was passed: parse the file into a MetaGraphDef first.
    meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
  else:
    meta_graph_def = meta_graph_or_file

  imported_vars, imported_return_elements = (
      meta_graph.import_scoped_meta_graph_with_return_elements(
          meta_graph_def,
          clear_devices=clear_devices,
          import_scope=import_scope,
          return_elements=return_elements,
          **kwargs))

  saver = _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
                                                 imported_vars)
  return saver, imported_return_elements


def _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
                                           imported_vars):
  """Return a saver for restoring variable values to an imported MetaGraph."""
  if meta_graph_def.HasField("saver_def"):
    # Infer the scope that is prepended by `import_scoped_meta_graph`.
    scope = import_scope
    var_names = list(imported_vars.keys())
    if var_names:
      # Derive the actual prepended scope from an imported variable name:
      # the variable's full name is "<scope><key>", so strip the key suffix.
      sample_key = var_names[0]
      sample_var = imported_vars[sample_key]
      scope = sample_var.name[:-len(sample_key)]

    return Saver(saver_def=meta_graph_def.saver_def, name=scope)
  else:
    if variables._all_saveable_objects(scope=import_scope):  # pylint: disable=protected-access
      # Return the default saver instance for all graph variables.
      return Saver()
    else:
      # If no graph variables exist, then a Saver cannot be constructed.
      logging.info("Saver not created because there are no variables in the"
                   " graph to restore")
      return None


@tf_export(v1=["train.export_meta_graph"])
def export_meta_graph(filename=None,
                      meta_info_def=None,
                      graph_def=None,
                      saver_def=None,
                      collection_list=None,
                      as_text=False,
                      graph=None,
                      export_scope=None,
                      clear_devices=False,
                      clear_extraneous_savers=False,
                      strip_default_attrs=False,
                      save_debug_info=False,
                      **kwargs):
  # pylint: disable=line-too-long
  """Returns `MetaGraphDef` proto.

  Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the generated
      `MetaGraphDef` protocol buffer.
    meta_info_def: `MetaInfoDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract the
      subgraph. The scope name will be stripped from the node definitions for
      easy import later into new name scopes. If `None`, the whole graph is
      exported. graph_def and export_scope cannot both be specified.
    clear_devices: Whether or not to clear the device field for an `Operation`
      or `Tensor` during export.
    clear_extraneous_savers: Remove any Saver-related information from the graph
      (both Save/Restore ops and SaverDefs) that are not associated with the
      provided SaverDef.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
      removed from the NodeDefs. For a detailed guide, see
      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
      which in the same directory of filename and with `_debug` added before
      the file extension.
    **kwargs: Optional keyed arguments.

  Returns:
    A `MetaGraphDef` proto.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Exporting/importing meta graphs is not supported unless both `graph_def` and
  `graph` are provided. No graph exists when eager execution is enabled.
  @end_compatibility
  """
  # pylint: enable=line-too-long
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise RuntimeError("Exporting/importing meta graphs is not supported when "
                       "eager execution is enabled. No graph exists when eager "
                       "execution is enabled.")
  meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
      filename=filename,
      meta_info_def=meta_info_def,
      graph_def=graph_def,
      saver_def=saver_def,
      collection_list=collection_list,
      as_text=as_text,
      graph=graph,
      export_scope=export_scope,
      clear_devices=clear_devices,
      clear_extraneous_savers=clear_extraneous_savers,
      strip_default_attrs=strip_default_attrs,
      save_debug_info=save_debug_info,
      **kwargs)
  return meta_graph_def


def _wrap_restore_error_with_msg(err, extra_verbiage):
  """Returns a copy of `err` whose message explains common restore failures."""
  err_msg = ("Restoring from checkpoint failed. This is most likely "
             "due to {} from the checkpoint. Please ensure that you "
             "have not altered the graph expected based on the checkpoint. "
             "Original error:\n\n{}").format(extra_verbiage, err.message)
  return err.__class__(err.node_def, err.op, err_msg)


# Makes Savers in the SAVERS collection round-trip through MetaGraphDef
# serialization via Saver.to_proto/from_proto.
ops.register_proto_function(
    ops.GraphKeys.SAVERS,
    proto_type=saver_pb2.SaverDef,
    to_proto=Saver.to_proto,
    from_proto=Saver.from_proto)


def object_graph_key_mapping(checkpoint_path):
  """Return name to key mappings from the checkpoint.

  Args:
    checkpoint_path: string, path to object-based checkpoint

  Returns:
    Dictionary mapping tensor names to checkpoint keys.
  """
  reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
  object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
  object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
  object_graph_proto.ParseFromString(object_graph_string)
  names_to_keys = {}
  for node in object_graph_proto.nodes:
    for attribute in node.attributes:
      names_to_keys[attribute.full_name] = attribute.checkpoint_key
  return names_to_keys


def saver_from_object_based_checkpoint(checkpoint_path,
                                       var_list=None,
                                       builder=None,
                                       names_to_keys=None,
                                       cached_saver=None):
  """Return a `Saver` which reads from an object-based checkpoint.

  This function validates that all variables in the variables list are remapped
  in the object-based checkpoint (or `names_to_keys` dict if provided). A
  saver will be created with the list of remapped variables.

  The `cached_saver` argument allows the user to pass in a previously created
  saver, so multiple `saver.restore()` calls don't pollute the graph when graph
  building. This assumes that keys are consistent, meaning that the
    1) `checkpoint_path` checkpoint, and
    2) checkpoint used to create the `cached_saver`
  are the same type of object-based checkpoint. If this argument is set, this
  function will simply validate that all variables have been remapped by the
  checkpoint at `checkpoint_path`.

  Note that in general, `tf.train.Checkpoint` should be used to restore/save an
  object-based checkpoint.

  Args:
    checkpoint_path: string, path to object-based checkpoint
    var_list: list of `Variables` that appear in the checkpoint. If `None`,
      `var_list` will be set to all saveable objects.
    builder: a `BaseSaverBuilder` instance. If `None`, a new `BulkSaverBuilder`
      will be created.
    names_to_keys: dict mapping string tensor names to checkpoint keys.
If 1672 `None`, this dict will be generated from the checkpoint file. 1673 cached_saver: Cached `Saver` object with remapped variables. 1674 1675 Returns: 1676 `Saver` with remapped variables for reading from an object-based checkpoint. 1677 1678 Raises: 1679 ValueError if the checkpoint provided is not an object-based checkpoint. 1680 NotFoundError: If one of the variables in `var_list` can not be found in the 1681 checkpoint. This could mean the checkpoint or `names_to_keys` mapping is 1682 missing the variable. 1683 """ 1684 if names_to_keys is None: 1685 try: 1686 names_to_keys = object_graph_key_mapping(checkpoint_path) 1687 except errors.NotFoundError: 1688 raise ValueError("Checkpoint in %s not an object-based checkpoint." % 1689 checkpoint_path) 1690 if var_list is None: 1691 var_list = variables._all_saveable_objects() # pylint: disable=protected-access 1692 if builder is None: 1693 builder = BulkSaverBuilder() 1694 1695 saveables = saveable_object_util.validate_and_slice_inputs(var_list) 1696 current_names = set() 1697 for saveable in saveables: 1698 for spec in saveable.specs: 1699 current_names.add(spec.name) 1700 previous_names = set(names_to_keys.keys()) 1701 missing_names = current_names - previous_names 1702 if missing_names: 1703 extra_names = previous_names - current_names 1704 intersecting_names = previous_names.intersection(current_names) 1705 raise errors.NotFoundError( 1706 None, 1707 None, 1708 message=( 1709 "\n\nExisting variables not in the checkpoint: %s\n\n" 1710 "Variables names when this checkpoint was written which don't " 1711 "exist now: %s\n\n" 1712 "(%d variable name(s) did match)\n\n" 1713 "Could not find some variables in the checkpoint (see names " 1714 "above). Saver was attempting to load an object-based checkpoint " 1715 "(saved using tf.train.Checkpoint or tf.keras.Model.save_weights) " 1716 "using variable names. 
If the checkpoint was written with eager " 1717 "execution enabled, it's possible that variable names have " 1718 "changed (for example missing a '_1' suffix). It's also " 1719 "possible that there are new variables which did not exist " 1720 "when the checkpoint was written. You can construct a " 1721 "Saver(var_list=...) with only the variables which previously " 1722 "existed, and if variable names have changed you may need to " 1723 "make this a dictionary with the old names as keys. If you're " 1724 "using an Estimator, you'll need to return a tf.train.Saver " 1725 "inside a tf.train.Scaffold from your model_fn.") % 1726 (", ".join(sorted(missing_names)), ", ".join( 1727 sorted(extra_names)), len(intersecting_names))) 1728 for saveable in saveables: 1729 for spec in saveable.specs: 1730 spec.name = names_to_keys[spec.name] 1731 if cached_saver is None: 1732 return Saver(saveables) 1733 return cached_saver 1734