# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=invalid-name
"""Save and restore variables.

Symbols in this file are deprecated. See replacements in
tensorflow/python/training/trackable and tensorflow/python/training/saving.
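
A minimal object-based replacement (a sketch; `model` is a hypothetical
trackable object such as a Keras model, and the path is illustrative):

```python
ckpt = tf.train.Checkpoint(model=model)
save_path = ckpt.save('/tmp/training_ckpts/ckpt')
status = ckpt.restore(save_path)
```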
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os.path
import threading
import time

import numpy as np
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import training_util
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export

# TODO(allenl): Remove these aliases once all users are migrated off.
get_checkpoint_state = checkpoint_management.get_checkpoint_state
update_checkpoint_state = checkpoint_management.update_checkpoint_state
generate_checkpoint_state_proto = (
    checkpoint_management.generate_checkpoint_state_proto)
latest_checkpoint = checkpoint_management.latest_checkpoint
checkpoint_exists = checkpoint_management.checkpoint_exists
get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes
remove_checkpoint = checkpoint_management.remove_checkpoint


# Captures the timestamp of the first Saver object instantiation or end of a
# save operation. Can be accessed by multiple Saver instances.
_END_TIME_OF_LAST_WRITE = None
_END_TIME_OF_LAST_WRITE_LOCK = threading.Lock()

# API label for cell name used in checkpoint metrics.
_SAVER_LABEL = "saver_v1"


def _get_duration_microseconds(start_time_seconds, end_time_seconds):
  if end_time_seconds < start_time_seconds:
    # Avoid returning negative value in case of clock skew.
    return 0
  return round((end_time_seconds - start_time_seconds) * 1000000)


class BaseSaverBuilder(object):
  """Base class for Savers.

  Can be extended to create different Ops.
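
  For example, a sketch of a subclass that overrides `restore_op` so that
  "restored" values are zeros rather than values read from the checkpoint
  (the class name is hypothetical and for illustration only):

  ```python
  class ZeroRestoreSaverBuilder(BaseSaverBuilder):

    def restore_op(self, filename_tensor, saveable, preferred_shard):
      # One tensor per SaveSpec, matching each spec's shape and dtype.
      return [
          array_ops.zeros(array_ops.shape(spec.tensor), spec.dtype)
          for spec in saveable.specs
      ]
  ```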
  """

  SaveSpec = saveable_object.SaveSpec
  SaveableObject = saveable_object.SaveableObject

  # Aliases for code which was moved but still has lots of users.
  VariableSaveable = saveable_object_util.ReferenceVariableSaveable
  ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable

  def __init__(self, write_version=saver_pb2.SaverDef.V2):
    self._write_version = write_version

  def save_op(self, filename_tensor, saveables):
    """Create an Op to save 'saveables'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor.
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      An Operation that saves the variables.

    Raises:
      RuntimeError: (implementation detail) if "self._write_version" is an
        unexpected value.
    """
    # pylint: disable=protected-access
    tensor_names = []
    tensors = []
    tensor_slices = []
    for saveable in saveables:
      for spec in saveable.specs:
        tensor_names.append(spec.name)
        tensors.append(spec.tensor)
        tensor_slices.append(spec.slice_spec)
    if self._write_version == saver_pb2.SaverDef.V1:
      return io_ops._save(
          filename=filename_tensor,
          tensor_names=tensor_names,
          tensors=tensors,
          tensor_slices=tensor_slices)
    elif self._write_version == saver_pb2.SaverDef.V2:
      # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
      # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
      return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                            tensors)
    else:
      raise RuntimeError("Unexpected write_version: " +
                         str(self._write_version))

  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    """Restore all tensors contained in saveables.

    By default, this issues separate calls to `restore_op` for each saveable.
    Subclasses may override to load multiple saveables in a single call.

    Args:
      filename_tensor: String Tensor.
      saveables: List of BaseSaverBuilder.SaveableObject objects.
      preferred_shard: Int. Shard to open first when loading a sharded file.
      restore_sequentially: Unused. Bool. If true, each restore is sequential.

    Returns:
      A list of Tensors resulting from reading 'saveables' from 'filename'.
    """
    del restore_sequentially
    all_tensors = []
    for saveable in saveables:
      if saveable.device:
        device = saveable_object_util.set_cpu0(saveable.device)
      else:
        device = None
      with ops.device(device):
        all_tensors.extend(
            self.restore_op(filename_tensor, saveable, preferred_shard))
    return all_tensors

  # pylint: disable=unused-argument
  def restore_op(self, filename_tensor, saveable, preferred_shard):
    """Create ops to restore 'saveable'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor.
      saveable: A BaseSaverBuilder.SaveableObject object.
      preferred_shard: Int. Shard to open first when loading a sharded file.

    Returns:
      A list of Tensors resulting from reading 'saveable' from 'filename'.
    """
    # pylint: disable=protected-access
    tensors = []
    for spec in saveable.specs:
      tensors.append(
          io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec],
                            [spec.dtype])[0])

    return tensors

  # pylint: enable=unused-argument

  def sharded_filename(self, filename_tensor, shard, num_shards):
    """Append sharding information to a filename.

    Args:
      filename_tensor: A string tensor.
      shard: Integer. The shard for the filename.
      num_shards: An int Tensor for the number of shards.

    Returns:
      A string tensor.
    """
    return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)

  def _AddSaveOps(self, filename_tensor, saveables):
    """Add ops to save variables that are on the same shard.

    Args:
      filename_tensor: String Tensor.
      saveables: A list of SaveableObject objects.

    Returns:
      A tensor with the filename used to save.
    """
    save = self.save_op(filename_tensor, saveables)
    return control_flow_ops.with_dependencies([save], filename_tensor)

  def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device):
    """Add ops to save the params per shard, for the V2 format.

    Note that the sharded save procedure for the V2 format is different from
    V1: there is a special "merge" step that merges the small metadata produced
    from each device.

    Args:
      checkpoint_prefix: scalar String Tensor. Interpreted *NOT AS A FILENAME*,
        but as a prefix of a V2 checkpoint.
      per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
        returned by _GroupByDevices().

    Returns:
      An op to save the variables, which, when evaluated, returns the prefix
      "<user-fed prefix>" only and does not include the sharded spec suffix.
    """
    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore(). Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>.
    # * If checkpoint_prefix is an S3 bucket path, ".part" is appended to it.
    # * Otherwise "_temp/part" is appended, normalized relative to the OS.
    # Example:
    # During runtime, a temporary directory is first created, which contains
    # files
    #
    #   <train dir>/myckpt_temp/
    #      part-?????-of-?????{.index, .data-00000-of-00001}
    #
    # Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #   <train dir>/
    #      myckpt{.index, .data-?????-of-?????}
    #
    # Filesystems with eventual consistency (such as S3) don't need a
    # temporary location. Using a temporary directory in those cases might
    # cause situations where files are not available during copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case. Save() and Restore() work with the
    # prefix directly, instead of any physical pathname. (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      _SHARDED_SUFFIX = array_ops.where(
          string_ops.regex_full_match(checkpoint_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant(os.path.normpath("_temp/part")))
      tmp_checkpoint_prefix = string_ops.string_join(
          [checkpoint_prefix, _SHARDED_SUFFIX])

    num_shards = len(per_device)
    sharded_saves = []
    sharded_prefixes = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    last_device = None
    for shard, (device, saveables) in enumerate(per_device):
      last_device = device
      with ops.device(saveable_object_util.set_cpu0(device)):
        sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
                                                 num_shards_tensor)
        sharded_prefixes.append(sharded_filename)
        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))

    with ops.control_dependencies([x.op for x in sharded_saves]):
      # Co-locates the merge step with the last device.
      with ops.device(saveable_object_util.set_cpu0(last_device)):
        # V2 format write path consists of a metadata merge step. Once merged,
        # attempts to delete the temporary directory, "<user-fed prefix>_temp".
        merge_step = gen_io_ops.merge_v2_checkpoints(
            sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
        with ops.control_dependencies([merge_step]):
          # Returns the prefix "<user-fed prefix>" only. DOES NOT include the
          # sharded spec suffix.
          return array_ops.identity(checkpoint_prefix)

  def _AddShardedSaveOps(self, filename_tensor, per_device):
    """Add ops to save the params per shard.

    Args:
      filename_tensor: a scalar String Tensor.
      per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
        returned by _GroupByDevices().

    Returns:
      An op to save the variables.
    """
    if self._write_version == saver_pb2.SaverDef.V2:
      return self._AddShardedSaveOpsForV2(filename_tensor, per_device)

    num_shards = len(per_device)
    sharded_saves = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    for shard, (device, saveables) in enumerate(per_device):
      with ops.device(device):
        sharded_filename = self.sharded_filename(filename_tensor, shard,
                                                 num_shards_tensor)
        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
    # Return the sharded name for the save path.
    with ops.control_dependencies([x.op for x in sharded_saves]):
      return gen_io_ops.sharded_filespec(filename_tensor, num_shards_tensor)

  def _AddRestoreOps(self,
                     filename_tensor,
                     saveables,
                     restore_sequentially,
                     reshape,
                     preferred_shard=-1,
                     name="restore_all"):
    """Add operations to restore saveables.

    Args:
      filename_tensor: Tensor for the path of the file to load.
      saveables: A list of SaveableObject objects.
      restore_sequentially: True if we want to restore variables sequentially
        within a shard.
      reshape: True if we want to reshape loaded tensors to the shape of the
        corresponding variable.
      preferred_shard: Shard to open first when loading a sharded file.
      name: Name for the returned op.

    Returns:
      An Operation that restores the variables.
    """
    all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
                                    restore_sequentially)

    assign_ops = []
    idx = 0
    # Load and optionally reshape on the CPU, as string tensors are not
    # available on the GPU.
    # TODO(touts): Re-enable restore on GPU when we can support annotating
    # string tensors as "HostMemory" inputs.
    for saveable in saveables:
      shapes = None
      if reshape:
        # Compute the shapes, let the restore op decide if and how to do
        # the reshape.
        shapes = []
        for spec in saveable.specs:
          v = spec.tensor
          shape = v.get_shape()
          if not shape.is_fully_defined():
            shape = array_ops.shape(v)
          shapes.append(shape)
      saveable_tensors = all_tensors[idx:idx + len(saveable.specs)]
      idx += len(saveable.specs)
      assign_ops.append(saveable.restore(saveable_tensors, shapes))

    # Create a Noop that has control dependencies from all the updates.
    return control_flow_ops.group(*assign_ops, name=name)

  def _AddShardedRestoreOps(self, filename_tensor, per_device,
                            restore_sequentially, reshape):
    """Add Ops to restore variables from multiple devices.

    Args:
      filename_tensor: Tensor for the path of the file to load.
      per_device: A list of (device, SaveableObject) pairs, as returned by
        _GroupByDevices().
      restore_sequentially: True if we want to restore variables sequentially
        within a shard.
      reshape: True if we want to reshape loaded tensors to the shape of the
        corresponding variable.

    Returns:
      An Operation that restores the variables.
    """
    sharded_restores = []
    for shard, (device, saveables) in enumerate(per_device):
      with ops.device(device):
        sharded_restores.append(
            self._AddRestoreOps(
                filename_tensor,
                saveables,
                restore_sequentially,
                reshape,
                preferred_shard=shard,
                name="restore_shard"))
    return control_flow_ops.group(*sharded_restores, name="restore_all")

  def _GroupByDevices(self, saveables):
    """Group Variable tensor slices per device.

    TODO(touts): Make sure that all the devices found are on different
    job/replica/task/cpu|gpu. It would be bad if 2 were on the same device.
    It can happen if the devices are unspecified.

    Args:
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      A list of (device_name, BaseSaverBuilder.SaveableObject) tuples, sorted
      by ascending device_name.

    Raises:
      ValueError: If the tensors of a saveable are on different devices.
    """
    per_device = collections.defaultdict(lambda: [])
    for saveable in saveables:
      canonical_device = set(
          pydev.canonical_name(spec.device) for spec in saveable.specs)
      if len(canonical_device) != 1:
        raise ValueError("All tensors of a saveable object must be "
                         "on the same device: %s" % saveable.name)
      per_device[canonical_device.pop()].append(saveable)
    return sorted(per_device.items(), key=lambda t: t[0])

  def build(self,
            names_to_saveables,
            reshape=False,
            sharded=False,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=10000.0,
            name=None,
            restore_sequentially=False,
            filename="model"):
    """Builds save/restore graph nodes or runs save/restore in eager mode.

    Args:
      names_to_saveables: A dictionary mapping name to a Variable or
        SaveableObject. Each name will be associated with the corresponding
        variable in the checkpoint.
      reshape: If True, allows restoring parameters from a checkpoint where
        the parameters have a different shape. This is only needed when you
        try to restore from a Dist-Belief checkpoint, and only sometimes.
      sharded: If True, shard the checkpoints, one per device that has
        Variable nodes.
      max_to_keep: Maximum number of checkpoints to keep. As new checkpoints
        are created, old ones are deleted. If None or 0, no checkpoints are
        deleted from the filesystem but only the last one is kept in the
        `checkpoint` file. Presently the number is only roughly enforced. For
        example, in case of restarts more than max_to_keep checkpoints may be
        kept.
      keep_checkpoint_every_n_hours: How often checkpoints should be kept.
        Defaults to 10,000 hours.
      name: String. Optional name to use as a prefix when adding operations.
      restore_sequentially: A Bool, which if true, causes restore of different
        variables to happen sequentially within each device.
      filename: If known at graph construction time, filename used for variable
        loading/saving. If None, then the default name "model" will be used.

    Returns:
      A SaverDef proto.

    Raises:
      TypeError: If 'names_to_saveables' is not a dictionary mapping string
        keys to variable Tensors.
      ValueError: If any of the keys or values in 'names_to_saveables' is not
        unique.
    """
    return self._build_internal(
        names_to_saveables=names_to_saveables,
        reshape=reshape,
        sharded=sharded,
        max_to_keep=max_to_keep,
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
        name=name,
        restore_sequentially=restore_sequentially,
        filename=filename)

  def _build_internal(self,
                      names_to_saveables,
                      reshape=False,
                      sharded=False,
                      max_to_keep=5,
                      keep_checkpoint_every_n_hours=10000.0,
                      name=None,
                      restore_sequentially=False,
                      filename="model",
                      build_save=True,
                      build_restore=True):
    """build() with option to only perform save and restore."""
    if not context.executing_eagerly() and (not build_save or
                                            not build_restore):
      raise ValueError("save and restore operations need to be built together "
                       "when eager execution is not enabled.")

    saveables = saveable_object_util.validate_and_slice_inputs(
        names_to_saveables)
    if max_to_keep is None:
      max_to_keep = 0

    with ops.name_scope(name, "save",
                        [saveable.op for saveable in saveables]) as name:
      # Add a placeholder string tensor for the filename.
      filename_tensor = array_ops.placeholder_with_default(
          filename or "model", shape=(), name="filename")
      # Keep the name "Const" for backwards compatibility.
      filename_tensor = array_ops.placeholder_with_default(
          filename_tensor, shape=(), name="Const")

      # Add the save ops.
      if sharded:
        per_device = self._GroupByDevices(saveables)
        if build_save:
          save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
        if build_restore:
          restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
                                                  restore_sequentially, reshape)
      else:
        if build_save:
          save_tensor = self._AddSaveOps(filename_tensor, saveables)
        if build_restore:
          restore_op = self._AddRestoreOps(filename_tensor, saveables,
                                           restore_sequentially, reshape)

      # In the following use case, it's possible to have restore_ops be called
      # something else:
      # - Build inference graph and export a meta_graph.
      # - Import the inference meta_graph.
      # - Extend the inference graph to a train graph.
      # - Export a new meta_graph.
      # Now the second restore_op will be called "restore_all_1".
      # As such, comment out the assert for now until we know whether
      # supporting such a usage model makes sense.
      #
      # assert restore_op.name.endswith("restore_all"), restore_op.name
      if context.executing_eagerly():
        # Store the tensor values to the tensor names.
        save_tensor_name = save_tensor.numpy() if build_save else ""
        return saver_pb2.SaverDef(
            filename_tensor_name=filename_tensor.numpy(),
            save_tensor_name=save_tensor_name,
            restore_op_name="",
            max_to_keep=max_to_keep,
            sharded=sharded,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            version=self._write_version)
      else:
        graph = ops.get_default_graph()
        # Do some sanity checking on collections containing
        # PartitionedVariables. If a saved collection has a PartitionedVariable,
        # the GraphDef needs to include concat ops to get the value (or there'll
        # be a lookup error on load).
        check_collection_list = graph.get_all_collection_keys()
        for collection_type in check_collection_list:
          for element in graph.get_collection(collection_type):
            if isinstance(element, variables.PartitionedVariable):
              try:
                graph.get_operation_by_name(element.name)
              except KeyError:
                # Create a concat op for this PartitionedVariable. The user may
                # not need it, but we'll try looking it up on MetaGraph restore
                # since it's in a collection.
                element.as_tensor()
        return saver_pb2.SaverDef(
            filename_tensor_name=filename_tensor.name,
            save_tensor_name=save_tensor.name,
            restore_op_name=restore_op.name,
            max_to_keep=max_to_keep,
            sharded=sharded,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            version=self._write_version)


class BulkSaverBuilder(BaseSaverBuilder):
  """SaverBuilder with support for bulk restoring multiple saveables."""

  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    # Ignored: bulk restore is internally sequential.
    del restore_sequentially
    restore_specs = []
    for saveable in saveables:
      for spec in saveable.specs:
        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))

    names, slices, dtypes = zip(*restore_specs)
    # Load all tensors onto CPU 0 for compatibility with existing code.
    with ops.device("cpu:0"):
      return io_ops.restore_v2(filename_tensor, names, slices, dtypes)


def _get_saver_or_default():
  """Returns the saver from the SAVERS collection, or creates a default one.

  This method is used by other members of the training module, such as
  `Scaffold`, or `CheckpointSaverHook`.

  Returns:
    `Saver`.

  Raises:
    RuntimeError: If the SAVERS collection already has more than one item.
  """
  collection_key = ops.GraphKeys.SAVERS
  savers = ops.get_collection(collection_key)
  if savers:
    if len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))
    return savers[0]
  saver = Saver(sharded=True, allow_empty=True)
  if saver is not None:
    ops.add_to_collection(collection_key, saver)
  return saver


@tf_export(v1=["train.Saver"])
class Saver(object):
  # pylint: disable=line-too-long
  """Saves and restores variables.

  @compatibility(TF2)
  `tf.compat.v1.train.Saver` is not supported for saving and restoring
  checkpoints in TF2.
  Please switch to `tf.train.Checkpoint` or
  `tf.keras.Model.save_weights`, which perform a more robust [object-based
  saving](https://www.tensorflow.org/guide/checkpoint#loading_mechanics).

  ### How to Rewrite Checkpoints

  Please rewrite your checkpoints immediately using the object-based checkpoint
  APIs.

  You can load a name-based checkpoint written by `tf.compat.v1.train.Saver`
  using `tf.train.Checkpoint.restore` or `tf.keras.Model.load_weights`. However,
  you may have to change the names of the variables in your model to match the
  variable names in the name-based checkpoint, which can be viewed with
  `tf.train.list_variables(path)`.

  Another option is to create an `assignment_map` that maps the name of the
  variables in the name-based checkpoint to the variables in your model, eg:
  ```
  {
      'sequential/dense/bias': model.variables[0],
      'sequential/dense/kernel': model.variables[1]
  }
  ```
  and use `tf.compat.v1.train.init_from_checkpoint(path, assignment_map)` to
  restore the name-based checkpoint.

  After restoring, re-encode your checkpoint
  using `tf.train.Checkpoint.save` or `tf.keras.Model.save_weights`.

  See the [Checkpoint compatibility](
  https://www.tensorflow.org/guide/migrate#checkpoint_compatibility)
  section of the migration guide for more details.


  ### Checkpoint Management in TF2

  Use `tf.train.CheckpointManager` to manage checkpoints in TF2.
  `tf.train.CheckpointManager` offers equivalent `keep_checkpoint_every_n_hours`
  and `max_to_keep` parameters.

  To recover the latest checkpoint,

  ```
  checkpoint = tf.train.Checkpoint(model)
  # `directory` and `max_to_keep` are required; the path is illustrative.
  manager = tf.train.CheckpointManager(
      checkpoint, directory='/tmp/ckpt_dir', max_to_keep=5)
  status = checkpoint.restore(manager.latest_checkpoint)
  ```

  `tf.train.CheckpointManager` also writes a [`CheckpointState` proto]
  (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/checkpoint_state.proto)
  which contains the timestamp when each checkpoint was created.

  ### Writing `MetaGraphDef`s in TF2

  To replace `tf.compat.v1.train.Saver.save(write_meta_graph=True)`, use
  `tf.saved_model.save` to write the `MetaGraphDef` (which is contained in
  `saved_model.pb`).
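
  For example, a minimal sketch (the export directory is illustrative):

  ```
  model = tf.Module()  # or any other trackable object
  tf.saved_model.save(model, '/tmp/exported_model')
  ```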

  @end_compatibility

  See [Variables](https://tensorflow.org/guide/variables)
  for an overview of variables, saving and restoring.

  The `Saver` class adds ops to save and restore variables to and from
  *checkpoints*. It also provides convenience methods to run these ops.

  Checkpoints are binary files in a proprietary format which map variable names
  to tensor values. The best way to examine the contents of a checkpoint is to
  load it using a `Saver`.

  Savers can automatically number checkpoint filenames with a provided counter.
  This lets you keep multiple checkpoints at different steps while training a
  model. For example you can number the checkpoint filenames with the training
  step number. To avoid filling up disks, savers manage checkpoint files
  automatically. For example, they can keep only the N most recent files, or
  one checkpoint for every N hours of training.

  You number checkpoint filenames by passing a value to the optional
  `global_step` argument to `save()`:

  ```python
  saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
  ...
  saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
  ```

  Additionally, optional arguments to the `Saver()` constructor let you control
  the proliferation of checkpoint files on disk:

  * `max_to_keep` indicates the maximum number of recent checkpoint files to
    keep. As new files are created, older files are deleted. If None or 0,
    no checkpoints are deleted from the filesystem but only the last one is
    kept in the `checkpoint` file. Defaults to 5 (that is, the 5 most recent
    checkpoint files are kept.)

  * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent
    `max_to_keep` checkpoint files, you might want to keep one checkpoint file
    for every N hours of training. This can be useful if you want to later
    analyze how a model progressed during a long training session. For
    example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep
    one checkpoint file for every 2 hours of training. The default value of
    10,000 hours effectively disables the feature.

  Note that you still have to call the `save()` method to save the model.
  Passing these arguments to the constructor will not save variables
  automatically for you.

  A training program that saves regularly looks like:

  ```python
  ...
  # Create a saver.
  saver = tf.compat.v1.train.Saver(...variables...)
  # Launch the graph and train, saving the model every 1,000 steps.
  sess = tf.compat.v1.Session()
  for step in xrange(1000000):
    sess.run(..training_op..)
    if step % 1000 == 0:
      # Append the step number to the checkpoint name:
      saver.save(sess, 'my-model', global_step=step)
  ```

  In addition to checkpoint files, savers keep a protocol buffer on disk with
  the list of recent checkpoints. This is used to manage numbered checkpoint
  files and by `latest_checkpoint()`, which makes it easy to discover the path
  to the most recent checkpoint. That protocol buffer is stored in a file named
  'checkpoint' next to the checkpoint files.

  If you create several savers, you can specify a different filename for the
  protocol buffer file in the call to `save()`.
  """
  # pylint: enable=line-too-long

  def __init__(self,
               var_list=None,
               reshape=False,
               sharded=False,
               max_to_keep=5,
               keep_checkpoint_every_n_hours=10000.0,
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False,
               filename=None):
    """Creates a `Saver`.

    The constructor adds ops to save and restore variables.

    `var_list` specifies the variables that will be saved and restored. It can
    be passed as a `dict` or a list:

    * A `dict` of names to variables: The keys are the names that will be
      used to save or restore the variables in the checkpoint files.
    * A list of variables: The variables will be keyed with their op name in
      the checkpoint files.

    For example:

    ```python
    v1 = tf.Variable(..., name='v1')
    v2 = tf.Variable(..., name='v2')

    # Pass the variables as a dict:
    saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2})

    # Or pass them as a list.
    saver = tf.compat.v1.train.Saver([v1, v2])
    # Passing a list is equivalent to passing a dict with the variable op names
    # as keys:
    saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]})
    ```

    Note: the newer `AutoTrackable` API is not supported by `Saver`. In this
    case, the `tf.train.Checkpoint` class should be used.

    The optional `reshape` argument, if `True`, allows restoring a variable
    from a save file where the variable had a different shape, but the same
    number of elements and type. This is useful if you have reshaped a
    variable and want to reload it from an older checkpoint.

    The optional `sharded` argument, if `True`, instructs the saver to shard
    checkpoints per device.

    Args:
      var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
        names to `SaveableObject`s. If `None`, defaults to the list of all
        saveable objects.
      reshape: If `True`, allows restoring parameters from a checkpoint where
        the variables have a different shape.
      sharded: If `True`, shard the checkpoints, one per device.
      max_to_keep: Maximum number of recent checkpoints to keep. Defaults to 5.
      keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to
        10,000 hours.
      name: String. Optional name to use as a prefix when adding operations.
      restore_sequentially: A `Bool`, which if true, causes restore of
        different variables to happen sequentially within each device. This
        can lower memory usage when restoring very large models.
      saver_def: Optional `SaverDef` proto to use instead of running the
        builder. This is only useful for specialty code that wants to recreate
        a `Saver` object for a previously built `Graph` that had a `Saver`.
        The `saver_def` proto should be the one returned by the
        `as_saver_def()` call of the `Saver` that was created for that `Graph`.
      builder: Optional `SaverBuilder` to use if a `saver_def` was not
        provided. Defaults to `BulkSaverBuilder()`.
      defer_build: If `True`, defer adding the save and restore ops to the
        `build()` call. In that case `build()` should be called before
        finalizing the graph or using the saver.
      allow_empty: If `False` (default) raise an error if there are no
        variables in the graph. Otherwise, construct the saver anyway and make
        it a no-op.
      write_version: controls what format to use when saving checkpoints. It
        also affects certain filepath matching logic. The V2 format is the
        recommended choice: it is much more optimized than V1 in terms of
        memory required and latency incurred during restore. Regardless of
        this flag, the Saver is able to restore from both V2 and V1
        checkpoints.
      pad_step_number: if True, pads the global step number in the checkpoint
        filepaths to some fixed width (8 by default). This is turned off by
        default.
      save_relative_paths: If `True`, will write relative paths to the
        checkpoint state file. This is needed if the user wants to copy the
        checkpoint directory and reload from the copied directory.
      filename: If known at graph construction time, filename used for
        variable loading/saving.

    Raises:
      TypeError: If `var_list` is invalid.
      ValueError: If any of the keys or values in `var_list` are not unique.
      RuntimeError: If eager execution is enabled and `var_list` does not
        specify a list of variables to save.

    @compatibility(eager)
    When eager execution is enabled, `var_list` must specify a `list` or
    `dict` of variables to save. Otherwise, a `RuntimeError` will be raised.

    Although Saver works in some cases when executing eagerly, it is
    fragile. Please switch to `tf.train.Checkpoint` or
    `tf.keras.Model.save_weights`, which perform a more robust object-based
    saving. These APIs will load checkpoints written by `Saver`.
    @end_compatibility
    """
    global _END_TIME_OF_LAST_WRITE
    with _END_TIME_OF_LAST_WRITE_LOCK:
      if _END_TIME_OF_LAST_WRITE is None:
        _END_TIME_OF_LAST_WRITE = time.time()

    if defer_build and var_list:
      raise ValueError(
          "If `var_list` is provided then build cannot be deferred. "
          "Either set defer_build=False or var_list=None.")
    if context.executing_eagerly():
      logging.warning(
          "Saver is deprecated, please switch to tf.train.Checkpoint or "
          "tf.keras.Model.save_weights for training checkpoints. When "
          "executing eagerly variables do not necessarily have unique names, "
          "and so the variable.name-based lookups Saver performs are "
          "error-prone.")
      if var_list is None:
        raise RuntimeError(
            "When eager execution is enabled, `var_list` must specify a list "
            "or dict of variables to save")
    self._var_list = var_list
    self._reshape = reshape
    self._sharded = sharded
    self._max_to_keep = max_to_keep
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    self._name = name
    self._restore_sequentially = restore_sequentially
    self.saver_def = saver_def
    self._builder = builder
    self._is_built = False
    self._allow_empty = allow_empty
    self._is_empty = None
    self._write_version = write_version
    self._pad_step_number = pad_step_number
    self._filename = filename
    self._last_checkpoints = []
    self._checkpoints_to_be_deleted = []
    if context.executing_eagerly():
      self._next_checkpoint_time = (
          time.time() + self._keep_checkpoint_every_n_hours * 3600)
    elif not defer_build:
      self.build()
    if self.saver_def:
      self._check_saver_def()
      self._write_version = self.saver_def.version
    self._save_relative_paths = save_relative_paths
    # For compatibility with object-based checkpoints, we may build a second
    # Saver to read the renamed keys.
    self._object_restore_saver = None

  def build(self):
    if context.executing_eagerly():
      raise RuntimeError("Use save/restore instead of build in eager mode.")
    self._build(self._filename, build_save=True, build_restore=True)

  def _build_eager(self, checkpoint_path, build_save, build_restore):
    self._build(
        checkpoint_path, build_save=build_save, build_restore=build_restore)

  def _build(self, checkpoint_path, build_save, build_restore):
    """Builds saver_def."""
    if not context.executing_eagerly():
      if self._is_built:
        return
      self._is_built = True

    if not self.saver_def or context.executing_eagerly():
      if self._builder is None:
        self._builder = BulkSaverBuilder(self._write_version)

      if self._var_list is None:
        # pylint: disable=protected-access
        self._var_list = variables._all_saveable_objects()
      if not self._var_list:
        if self._allow_empty:
          self._is_empty = True
          return
        else:
          raise ValueError("No variables to save")
      self._is_empty = False

      self.saver_def = self._builder._build_internal(  # pylint: disable=protected-access
          self._var_list,
          reshape=self._reshape,
          sharded=self._sharded,
          max_to_keep=self._max_to_keep,
          keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
          name=self._name,
          restore_sequentially=self._restore_sequentially,
          filename=checkpoint_path,
          build_save=build_save,
          build_restore=build_restore)
    elif self.saver_def and self._name:
      # Since self._name is used as a name_scope by builder(), we are
      # overloading the use of this field to represent the "import_scope" as
      # well.
      self.saver_def.filename_tensor_name = ops.prepend_name_scope(
          self.saver_def.filename_tensor_name, self._name)
      self.saver_def.save_tensor_name = ops.prepend_name_scope(
          self.saver_def.save_tensor_name, self._name)
      self.saver_def.restore_op_name = ops.prepend_name_scope(
          self.saver_def.restore_op_name, self._name)

    self._check_saver_def()
    if not context.executing_eagerly():
      # Updates next checkpoint time.
      # Set in __init__ when executing eagerly.
      self._next_checkpoint_time = (
          time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)

  def _check_saver_def(self):
    if not isinstance(self.saver_def, saver_pb2.SaverDef):
      raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
                       self.saver_def)
    if not context.executing_eagerly():
      if not self.saver_def.save_tensor_name:
        raise ValueError("saver_def must specify the save_tensor_name: %s" %
                         str(self.saver_def))
      if not self.saver_def.restore_op_name:
        raise ValueError("saver_def must specify the restore_op_name: %s" %
                         str(self.saver_def))

  def _CheckpointFilename(self, p):
    """Returns the checkpoint filename given a `(filename, time)` pair.

    Args:
      p: (filename, time) pair.

    Returns:
      Checkpoint file name.
    """
    name, _ = p
    return name

  def _RecordLastCheckpoint(self, latest_save_path):
    """Manages the list of the latest checkpoints."""
    if not self.saver_def.max_to_keep:
      return
    # Remove first from list if the same name was used before.
    for p in self._last_checkpoints:
      if latest_save_path == self._CheckpointFilename(p):
        self._last_checkpoints.remove(p)
    # Append new path to list.
    self._last_checkpoints.append((latest_save_path, time.time()))

    # If more than max_to_keep, remove oldest.
    if len(self._last_checkpoints) > self.saver_def.max_to_keep:
      self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0))

  def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"):
    """Deletes old checkpoints if necessary.

    `self._checkpoints_to_be_deleted` is going to contain checkpoints that are
    over `max_to_keep`. They are going to be deleted. If
    `keep_checkpoint_every_n_hours` was specified, keep an additional
    checkpoint every `N` hours. For example, if `N` is 0.5, an additional
    checkpoint is kept for every 0.5 hours of training; if `N` is 10, an
    additional checkpoint is kept for every 10 hours of training.

    Args:
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
    """
    if self._checkpoints_to_be_deleted:
      p = self._checkpoints_to_be_deleted.pop(0)
      # Do not delete the file if keep_checkpoint_every_n_hours is set and we
      # have reached N hours of training.
      should_keep = p[1] > self._next_checkpoint_time
      if should_keep:
        self._next_checkpoint_time += (
            self.saver_def.keep_checkpoint_every_n_hours * 3600)
        return

      # Otherwise delete the files.
      try:
        checkpoint_management.remove_checkpoint(
            self._CheckpointFilename(p), self.saver_def.version,
            meta_graph_suffix)
      except Exception as e:  # pylint: disable=broad-except
        logging.warning("Ignoring: %s", str(e))

  def as_saver_def(self):
    """Generates a `SaverDef` representation of this saver.

    Returns:
      A `SaverDef` proto.
    """
    return self.saver_def

  def to_proto(self, export_scope=None):
    """Converts this `Saver` to a `SaverDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `SaverDef` protocol buffer.
    """
    if export_scope is None:
      return self.saver_def

    if not (self.saver_def.filename_tensor_name.startswith(export_scope) and
            self.saver_def.save_tensor_name.startswith(export_scope) and
            self.saver_def.restore_op_name.startswith(export_scope)):
      return None

    saver_def = saver_pb2.SaverDef()
    saver_def.CopyFrom(self.saver_def)
    saver_def.filename_tensor_name = ops.strip_name_scope(
        saver_def.filename_tensor_name, export_scope)
    saver_def.save_tensor_name = ops.strip_name_scope(
        saver_def.save_tensor_name, export_scope)
    saver_def.restore_op_name = ops.strip_name_scope(saver_def.restore_op_name,
                                                     export_scope)
    return saver_def

  @staticmethod
  def from_proto(saver_def, import_scope=None):
    """Returns a `Saver` object created from `saver_def`.

    Args:
      saver_def: a `SaverDef` protocol buffer.
      import_scope: Optional `string`. Name scope to use.

    Returns:
      A `Saver` built from saver_def.
    """
    return Saver(saver_def=saver_def, name=import_scope)

  @property
  def last_checkpoints(self):
    """List of not-yet-deleted checkpoint filenames.

    You can pass any of the returned values to `restore()`.

    Returns:
      A list of checkpoint filenames, sorted from oldest to newest.
1118 """ 1119 return list(self._CheckpointFilename(p) for p in self._last_checkpoints) 1120 1121 def set_last_checkpoints(self, last_checkpoints): 1122 """DEPRECATED: Use set_last_checkpoints_with_time. 1123 1124 Sets the list of old checkpoint filenames. 1125 1126 Args: 1127 last_checkpoints: A list of checkpoint filenames. 1128 1129 Raises: 1130 AssertionError: If last_checkpoints is not a list. 1131 """ 1132 assert isinstance(last_checkpoints, list) 1133 # We use a timestamp of +inf so that this checkpoint will never be 1134 # deleted. This is both safe and backwards compatible to a previous 1135 # version of the code which used s[1] as the "timestamp". 1136 self._last_checkpoints = [(s, np.inf) for s in last_checkpoints] 1137 1138 def set_last_checkpoints_with_time(self, last_checkpoints_with_time): 1139 """Sets the list of old checkpoint filenames and timestamps. 1140 1141 Args: 1142 last_checkpoints_with_time: A list of tuples of checkpoint filenames and 1143 timestamps. 1144 1145 Raises: 1146 AssertionError: If last_checkpoints_with_time is not a list. 1147 """ 1148 assert isinstance(last_checkpoints_with_time, list) 1149 self._last_checkpoints = last_checkpoints_with_time 1150 1151 def recover_last_checkpoints(self, checkpoint_paths): 1152 """Recovers the internal saver state after a crash. 1153 1154 This method is useful for recovering the "self._last_checkpoints" state. 1155 1156 Globs for the checkpoints pointed to by `checkpoint_paths`. If the files 1157 exist, use their mtime as the checkpoint timestamp. 1158 1159 Args: 1160 checkpoint_paths: a list of checkpoint paths. 1161 """ 1162 checkpoints_with_mtimes = [] 1163 for checkpoint_path in checkpoint_paths: 1164 try: 1165 mtime = checkpoint_management.get_checkpoint_mtimes([checkpoint_path]) 1166 except errors.NotFoundError: 1167 # It's fine if some other thread/process is deleting some older 1168 # checkpoint concurrently. 1169 continue 1170 if mtime: 1171 checkpoints_with_mtimes.append((checkpoint_path, mtime[0])) 1172 self.set_last_checkpoints_with_time(checkpoints_with_mtimes) 1173 1174 def save(self, 1175 sess, 1176 save_path, 1177 global_step=None, 1178 latest_filename=None, 1179 meta_graph_suffix="meta", 1180 write_meta_graph=True, 1181 write_state=True, 1182 strip_default_attrs=False, 1183 save_debug_info=False): 1184 # pylint: disable=line-too-long 1185 """Saves variables. 1186 1187 This method runs the ops added by the constructor for saving variables. 1188 It requires a session in which the graph was launched. The variables to 1189 save must also have been initialized. 1190 1191 The method returns the path prefix of the newly created checkpoint files. 1192 This string can be passed directly to a call to `restore()`. 1193 1194 Args: 1195 sess: A Session to use to save the variables. 1196 save_path: String. Prefix of filenames created for the checkpoint. 1197 global_step: If provided the global step number is appended to `save_path` 1198 to create the checkpoint filenames. The optional argument can be a 1199 `Tensor`, a `Tensor` name or an integer. 1200 latest_filename: Optional name for the protocol buffer file that will 1201 contains the list of most recent checkpoints. That file, kept in the 1202 same directory as the checkpoint files, is automatically managed by the 1203 saver to keep track of recent checkpoints. Defaults to 'checkpoint'. 1204 meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. 1205 write_meta_graph: `Boolean` indicating whether or not to write the meta 1206 graph file. 

    Args:
      sess: A Session to use to save the variables.
      save_path: String. Prefix of filenames created for the checkpoint.
      global_step: If provided, the global step number is appended to
        `save_path` to create the checkpoint filenames. The optional argument
        can be a `Tensor`, a `Tensor` name or an integer.
      latest_filename: Optional name for the protocol buffer file that will
        contain the list of most recent checkpoints. That file, kept in the
        same directory as the checkpoint files, is automatically managed by
        the saver to keep track of recent checkpoints. Defaults to
        'checkpoint'.
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
      write_meta_graph: `Boolean` indicating whether or not to write the meta
        graph file.
      write_state: `Boolean` indicating whether or not to write the
        `CheckpointStateProto`.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will
        be removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued
        Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
        which is in the same directory as save_path and with `_debug` added
        before the file extension. This is only enabled when
        `write_meta_graph` is `True`.

    Returns:
      A string: path prefix used for the checkpoint files. If the saver is
      sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
      is the number of shards created.
      If the saver is empty, returns None.

    Raises:
      TypeError: If `sess` is not a `Session`.
      ValueError: If `latest_filename` contains path components, or if it
        collides with `save_path`.
      RuntimeError: If save and restore ops weren't built.
    """
    # pylint: enable=line-too-long
    start_time = time.time()
    if not self._is_built and not context.executing_eagerly():
      raise RuntimeError(
          "`build()` should be called before save if defer_build==True")
    if latest_filename is None:
      latest_filename = "checkpoint"
    if self._write_version != saver_pb2.SaverDef.V2:
      logging.warning("*******************************************************")
      logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
      logging.warning("Consider switching to the more efficient V2 format:")
      logging.warning("   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
      logging.warning("which is now on by default.")
      logging.warning("*******************************************************")

    if os.path.split(latest_filename)[0]:
      raise ValueError("'latest_filename' must not contain path components")

    save_path = compat.as_str(save_path)
    if global_step is not None:
      if not isinstance(global_step, compat.integral_types):
        global_step = training_util.global_step(sess, global_step)
      checkpoint_file = "%s-%d" % (save_path, global_step)
      if self._pad_step_number:
        # Zero-pads the step numbers, so that they are sorted when listed.
        checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
    else:
      checkpoint_file = save_path
      if os.path.basename(save_path) == latest_filename and not self._sharded:
        # Guard against collision between data file and checkpoint state file.
        raise ValueError(
            "'latest_filename' collides with 'save_path': '%s' and '%s'" %
            (latest_filename, save_path))

    if (not context.executing_eagerly() and
        not isinstance(sess, session.SessionInterface)):
      raise TypeError("'sess' must be a Session; %s" % sess)

    save_path_parent = os.path.dirname(save_path)
    if not self._is_empty:
      try:
        if context.executing_eagerly():
          self._build_eager(
              checkpoint_file, build_save=True, build_restore=False)
          model_checkpoint_path = self.saver_def.save_tensor_name
        else:
          model_checkpoint_path = sess.run(
              self.saver_def.save_tensor_name,
              {self.saver_def.filename_tensor_name: checkpoint_file})

        model_checkpoint_path = compat.as_str(model_checkpoint_path)
        if write_state:
          self._RecordLastCheckpoint(model_checkpoint_path)
          checkpoint_management.update_checkpoint_state_internal(
              save_dir=save_path_parent,
              model_checkpoint_path=model_checkpoint_path,
              all_model_checkpoint_paths=self.last_checkpoints,
              latest_filename=latest_filename,
              save_relative_paths=self._save_relative_paths)
          self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
      except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
        if not gfile.IsDirectory(save_path_parent):
          exc = ValueError(
              "Parent directory of {} doesn't exist, can't save.".format(
                  save_path))
        raise exc

    end_time = time.time()
    metrics.AddCheckpointWriteDuration(
        api_label=_SAVER_LABEL,
        microseconds=_get_duration_microseconds(start_time, end_time))
    global _END_TIME_OF_LAST_WRITE
    with _END_TIME_OF_LAST_WRITE_LOCK:
      metrics.AddTrainingTimeSaved(
          api_label=_SAVER_LABEL,
          microseconds=_get_duration_microseconds(_END_TIME_OF_LAST_WRITE,
                                                  end_time))
      _END_TIME_OF_LAST_WRITE = end_time

    if write_meta_graph:
      meta_graph_filename = checkpoint_management.meta_graph_filename(
          checkpoint_file, meta_graph_suffix=meta_graph_suffix)
      if not context.executing_eagerly():
        with sess.graph.as_default():
          self.export_meta_graph(
              meta_graph_filename,
              strip_default_attrs=strip_default_attrs,
              save_debug_info=save_debug_info)

    if self._is_empty:
      return None
    else:
      return model_checkpoint_path

  def export_meta_graph(self,
                        filename=None,
                        collection_list=None,
                        as_text=False,
                        export_scope=None,
                        clear_devices=False,
                        clear_extraneous_savers=False,
                        strip_default_attrs=False,
                        save_debug_info=False):
    # pylint: disable=line-too-long
    """Writes `MetaGraphDef` to save_path/filename.

    Args:
      filename: Optional meta_graph filename including the path.
      collection_list: List of string keys to collect.
      as_text: If `True`, writes the meta_graph as an ASCII proto.
      export_scope: Optional `string`. Name scope to remove.
      clear_devices: Whether or not to clear the device field for an
        `Operation` or `Tensor` during export.
      clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that is not associated
        with this Saver.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will
        be removed from the NodeDefs.
        For a detailed guide, see
        [Stripping Default-Valued
        Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
        which is in the same directory as filename and with `_debug` added
        before the file extension.

    Returns:
      A `MetaGraphDef` proto.
    """
    # pylint: enable=line-too-long
    return export_meta_graph(
        filename=filename,
        graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
        saver_def=self.saver_def,
        collection_list=collection_list,
        as_text=as_text,
        export_scope=export_scope,
        clear_devices=clear_devices,
        clear_extraneous_savers=clear_extraneous_savers,
        strip_default_attrs=strip_default_attrs,
        save_debug_info=save_debug_info)

  def restore(self, sess, save_path):
    """Restores previously saved variables.

    This method runs the ops added by the constructor for restoring variables.
    It requires a session in which the graph was launched. The variables to
    restore do not have to have been initialized, as restoring is itself a way
    to initialize variables.

    The `save_path` argument is typically a value previously returned from a
    `save()` call, or a call to `latest_checkpoint()`.
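
    For example, restoring the most recent checkpoint in a directory (the
    directory path is illustrative):

    ```python
    saver.restore(sess, tf.train.latest_checkpoint('/tmp/train'))
    ```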
Instead, " 1428 "consider loading object-based checkpoints using " 1429 "tf.train.Checkpoint().") 1430 self._object_restore_saver = saver_from_object_based_checkpoint( 1431 checkpoint_path=save_path, 1432 var_list=self._var_list, 1433 builder=self._builder, 1434 names_to_keys=names_to_keys, 1435 cached_saver=self._object_restore_saver) 1436 self._object_restore_saver.restore(sess=sess, save_path=save_path) 1437 except errors.InvalidArgumentError as err: 1438 # There is a mismatch between the graph and the checkpoint being loaded. 1439 # We add a more reasonable error message here to help users (b/110263146) 1440 raise _wrap_restore_error_with_msg( 1441 err, "a mismatch between the current graph and the graph") 1442 metrics.AddCheckpointReadDuration( 1443 api_label=_SAVER_LABEL, 1444 microseconds=_get_duration_microseconds(start_time, time.time())) 1445 1446 @staticmethod 1447 def _add_collection_def(meta_graph_def, key, export_scope=None): 1448 """Adds a collection to MetaGraphDef protocol buffer. 1449 1450 Args: 1451 meta_graph_def: MetaGraphDef protocol buffer. 1452 key: One of the GraphKeys or user-defined string. 1453 export_scope: Optional `string`. Name scope to remove. 1454 """ 1455 meta_graph.add_collection_def( 1456 meta_graph_def, key, export_scope=export_scope) 1457 1458 1459@tf_export(v1=["train.import_meta_graph"]) 1460def import_meta_graph(meta_graph_or_file, 1461 clear_devices=False, 1462 import_scope=None, 1463 **kwargs): 1464 """Recreates a Graph saved in a `MetaGraphDef` proto. 1465 1466 This function takes a `MetaGraphDef` protocol buffer as input. If 1467 the argument is a file containing a `MetaGraphDef` protocol buffer , 1468 it constructs a protocol buffer from the file content. The function 1469 then adds all the nodes from the `graph_def` field to the 1470 current graph, recreates all the collections, and returns a saver 1471 constructed from the `saver_def` field. 1472 1473 In combination with `export_meta_graph()`, this function can be used to 1474 1475 * Serialize a graph along with other Python objects such as `QueueRunner`, 1476 `Variable` into a `MetaGraphDef`. 1477 1478 * Restart training from a saved graph and checkpoints. 1479 1480 * Run inference from a saved graph and checkpoints. 1481 1482 ```Python 1483 ... 1484 # Create a saver. 1485 saver = tf.compat.v1.train.Saver(...variables...) 1486 # Remember the training_op we want to run by adding it to a collection. 1487 tf.compat.v1.add_to_collection('train_op', train_op) 1488 sess = tf.compat.v1.Session() 1489 for step in xrange(1000000): 1490 sess.run(train_op) 1491 if step % 1000 == 0: 1492 # Saves checkpoint, which by default also exports a meta_graph 1493 # named 'my-model-global_step.meta'. 1494 saver.save(sess, 'my-model', global_step=step) 1495 ``` 1496 1497 Later we can continue training from this saved `meta_graph` without building 1498 the model from scratch. 1499 1500 ```Python 1501 with tf.Session() as sess: 1502 new_saver = 1503 tf.train.import_meta_graph('my-save-dir/my-model-10000.meta') 1504 new_saver.restore(sess, 'my-save-dir/my-model-10000') 1505 # tf.get_collection() returns a list. In this example we only want 1506 # the first one. 1507 train_op = tf.get_collection('train_op')[0] 1508 for step in xrange(1000000): 1509 sess.run(train_op) 1510 ``` 1511 1512 NOTE: Restarting training from saved `meta_graph` only works if the 1513 device assignments have not changed. 

  Example:
  Variables, placeholders, and independent operations can also be stored, as
  shown in the following example.

  ```Python
  # Saving contents and operations.
  v1 = tf.placeholder(tf.float32, name="v1")
  v2 = tf.placeholder(tf.float32, name="v2")
  v3 = tf.math.multiply(v1, v2)
  vx = tf.Variable(10.0, name="vx")
  v4 = tf.add(v3, vx, name="v4")
  saver = tf.train.Saver([vx])
  sess = tf.Session()
  sess.run(tf.global_variables_initializer())
  sess.run(vx.assign(tf.add(vx, vx)))
  result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
  print(result)
  saver.save(sess, "./model_ex1")
  ```

  Later this model can be restored and its contents loaded.

  ```Python
  # Restoring variables and running operations.
  saver = tf.train.import_meta_graph("./model_ex1.meta")
  sess = tf.Session()
  saver.restore(sess, "./model_ex1")
  result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
  print(result)
  ```

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Whether or not to clear the device field for an `Operation`
      or `Tensor` during import.
    import_scope: Optional `string`. Name scope to add. Only used when
      initializing from protocol buffer.
    **kwargs: Optional keyed arguments.

  Returns:
    A saver constructed from `saver_def` in `MetaGraphDef` or None.

    A None value is returned if no variables exist in the `MetaGraphDef`
    (i.e., there are no variables to restore).

  Raises:
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Exporting/importing meta graphs is not supported. No graph exists when eager
  execution is enabled.
  @end_compatibility
  """  # pylint: disable=g-doc-exception
  return _import_meta_graph_with_return_elements(meta_graph_or_file,
                                                 clear_devices, import_scope,
                                                 **kwargs)[0]


def _import_meta_graph_with_return_elements(meta_graph_or_file,
                                            clear_devices=False,
                                            import_scope=None,
                                            return_elements=None,
                                            **kwargs):
  """Import MetaGraph, and return both a saver and returned elements."""
  if context.executing_eagerly():
    raise RuntimeError("Exporting/importing meta graphs is not supported when "
                       "eager execution is enabled. No graph exists when eager "
                       "execution is enabled.")
  if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
  else:
    meta_graph_def = meta_graph_or_file

  imported_vars, imported_return_elements = (
      meta_graph.import_scoped_meta_graph_with_return_elements(
          meta_graph_def,
          clear_devices=clear_devices,
          import_scope=import_scope,
          return_elements=return_elements,
          **kwargs))

  saver = _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
                                                 imported_vars)
  return saver, imported_return_elements


def _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
                                           imported_vars):
  """Return a saver for restoring variable values to an imported MetaGraph."""
  if meta_graph_def.HasField("saver_def"):
    # Infer the scope that is prepended by `import_scoped_meta_graph`.
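    # For example (names here are illustrative): if `imported_vars` maps the
    # original name "v0:0" to a variable now named "new_scope/v0:0", slicing
    # the key off the end of the variable's name leaves the prefix
    # "new_scope/".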
    scope = import_scope
    var_names = list(imported_vars.keys())
    if var_names:
      sample_key = var_names[0]
      sample_var = imported_vars[sample_key]
      scope = sample_var.name[:-len(sample_key)]

    return Saver(saver_def=meta_graph_def.saver_def, name=scope)
  else:
    if variables._all_saveable_objects(scope=import_scope):  # pylint: disable=protected-access
      # Return the default saver instance for all graph variables.
      return Saver()
    else:
      # If no graph variables exist, then a Saver cannot be constructed.
      logging.info("Saver not created because there are no variables in the"
                   " graph to restore")
      return None


@tf_export(v1=["train.export_meta_graph"])
def export_meta_graph(filename=None,
                      meta_info_def=None,
                      graph_def=None,
                      saver_def=None,
                      collection_list=None,
                      as_text=False,
                      graph=None,
                      export_scope=None,
                      clear_devices=False,
                      clear_extraneous_savers=False,
                      strip_default_attrs=False,
                      save_debug_info=False,
                      **kwargs):
  # pylint: disable=line-too-long
  """Returns `MetaGraphDef` proto.

  Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer so that it can be imported at a later time or
  in another location to restart training, run inference, or serve as a
  subgraph.
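
  A minimal sketch of exporting the default graph (the file name here is
  illustrative):

  ```Python
  # Build the graph (and optionally a Saver) first, then export it to disk.
  meta_graph_def = tf.compat.v1.train.export_meta_graph(
      filename='/tmp/my-model.meta')
  ```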

  Args:
    filename: Optional filename including the path for writing the generated
      `MetaGraphDef` protocol buffer.
    meta_info_def: `MetaInfoDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract the
      subgraph. The scope name will be stripped from the node definitions for
      easy import later into new name scopes. If `None`, the whole graph is
      exported. graph_def and export_scope cannot both be specified.
    clear_devices: Whether or not to clear the device field for an `Operation`
      or `Tensor` during export.
    clear_extraneous_savers: Remove any Saver-related information from the
      graph (both Save/Restore ops and SaverDefs) that are not associated with
      the provided SaverDef.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
      removed from the NodeDefs. For a detailed guide, see
      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
      which is in the same directory as filename and has `_debug` added before
      the file extension.
    **kwargs: Optional keyed arguments.

  Returns:
    A `MetaGraphDef` proto.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Exporting/importing meta graphs is not supported unless both `graph_def` and
  `graph` are provided. No graph exists when eager execution is enabled.
  @end_compatibility
  """
  # pylint: enable=line-too-long
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise RuntimeError("Exporting/importing meta graphs is not supported when "
                       "eager execution is enabled. No graph exists when eager "
                       "execution is enabled.")
  meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
      filename=filename,
      meta_info_def=meta_info_def,
      graph_def=graph_def,
      saver_def=saver_def,
      collection_list=collection_list,
      as_text=as_text,
      graph=graph,
      export_scope=export_scope,
      clear_devices=clear_devices,
      clear_extraneous_savers=clear_extraneous_savers,
      strip_default_attrs=strip_default_attrs,
      save_debug_info=save_debug_info,
      **kwargs)
  return meta_graph_def


def _wrap_restore_error_with_msg(err, extra_verbiage):
  err_msg = ("Restoring from checkpoint failed. This is most likely "
             "due to {} from the checkpoint. Please ensure that you "
             "have not altered the graph expected based on the checkpoint. "
             "Original error:\n\n{}").format(extra_verbiage, err.message)
  return err.__class__(err.node_def, err.op, err_msg)


ops.register_proto_function(
    ops.GraphKeys.SAVERS,
    proto_type=saver_pb2.SaverDef,
    to_proto=Saver.to_proto,
    from_proto=Saver.from_proto)


def object_graph_key_mapping(checkpoint_path):
  """Return name to key mappings from the checkpoint.

  Args:
    checkpoint_path: string, path to object-based checkpoint

  Returns:
    Dictionary mapping tensor names to checkpoint keys.
  """
  reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
  object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
  object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  object_graph_proto.ParseFromString(object_graph_string)
  names_to_keys = {}
  for node in object_graph_proto.nodes:
    for attribute in node.attributes:
      names_to_keys[attribute.full_name] = attribute.checkpoint_key
  return names_to_keys


def saver_from_object_based_checkpoint(checkpoint_path,
                                       var_list=None,
                                       builder=None,
                                       names_to_keys=None,
                                       cached_saver=None):
  """Return a `Saver` which reads from an object-based checkpoint.

  This function validates that all variables in the variables list are remapped
  in the object-based checkpoint (or `names_to_keys` dict if provided). A
  saver will be created with the list of remapped variables.

  The `cached_saver` argument allows the user to pass in a previously created
  saver, so multiple `saver.restore()` calls don't pollute the graph when graph
  building. This assumes that keys are consistent, meaning that the
    1) `checkpoint_path` checkpoint, and
    2) checkpoint used to create the `cached_saver`
  are the same type of object-based checkpoint. If this argument is set, this
  function will simply validate that all variables have been remapped by the
  checkpoint at `checkpoint_path`.

  Note that in general, `tf.train.Checkpoint` should be used to restore/save an
  object-based checkpoint.
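
  A minimal sketch, assuming an object-based checkpoint already exists at the
  illustrative path below:

  ```Python
  # Build a name-based Saver that can read the object-based checkpoint, then
  # restore through it.
  saver = saver_from_object_based_checkpoint('/tmp/ckpt-10')
  with tf.compat.v1.Session() as sess:
    saver.restore(sess, '/tmp/ckpt-10')
  ```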

  Args:
    checkpoint_path: string, path to object-based checkpoint
    var_list: list of `Variables` that appear in the checkpoint. If `None`,
      `var_list` will be set to all saveable objects.
    builder: a `BaseSaverBuilder` instance. If `None`, a new `BulkSaverBuilder`
      will be created.
    names_to_keys: dict mapping string tensor names to checkpoint keys. If
      `None`, this dict will be generated from the checkpoint file.
    cached_saver: Cached `Saver` object with remapped variables.

  Returns:
    `Saver` with remapped variables for reading from an object-based
    checkpoint.

  Raises:
    ValueError: If the checkpoint provided is not an object-based checkpoint.
    NotFoundError: If one of the variables in `var_list` cannot be found in the
      checkpoint. This could mean the checkpoint or `names_to_keys` mapping is
      missing the variable.
  """
  if names_to_keys is None:
    try:
      names_to_keys = object_graph_key_mapping(checkpoint_path)
    except errors.NotFoundError:
      raise ValueError("Checkpoint in %s not an object-based checkpoint." %
                       checkpoint_path)
  if var_list is None:
    var_list = variables._all_saveable_objects()  # pylint: disable=protected-access
  if builder is None:
    builder = BulkSaverBuilder()

  saveables = saveable_object_util.validate_and_slice_inputs(var_list)
  current_names = set()
  for saveable in saveables:
    for spec in saveable.specs:
      current_names.add(spec.name)
  previous_names = set(names_to_keys.keys())
  missing_names = current_names - previous_names
  if missing_names:
    extra_names = previous_names - current_names
    intersecting_names = previous_names.intersection(current_names)
    raise errors.NotFoundError(
        None,
        None,
        message=(
            "\n\nExisting variables not in the checkpoint: %s\n\n"
            "Variable names when this checkpoint was written which don't "
            "exist now: %s\n\n"
            "(%d variable name(s) did match)\n\n"
            "Could not find some variables in the checkpoint (see names "
            "above). Saver was attempting to load an object-based checkpoint "
            "(saved using tf.train.Checkpoint or tf.keras.Model.save_weights) "
            "using variable names. If the checkpoint was written with eager "
            "execution enabled, it's possible that variable names have "
            "changed (for example missing a '_1' suffix). It's also "
            "possible that there are new variables which did not exist "
            "when the checkpoint was written. You can construct a "
            "Saver(var_list=...) with only the variables which previously "
            "existed, and if variable names have changed you may need to "
            "make this a dictionary with the old names as keys. If you're "
            "using an Estimator, you'll need to return a tf.train.Saver "
            "inside a tf.train.Scaffold from your model_fn.") %
        (", ".join(sorted(missing_names)), ", ".join(
            sorted(extra_names)), len(intersecting_names)))
  for saveable in saveables:
    for spec in saveable.specs:
      spec.name = names_to_keys[spec.name]
  if cached_saver is None:
    return Saver(saveables)
  return cached_saver