# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=invalid-name
"""Save and restore variables.

Symbols in this file are deprecated. See replacements in
tensorflow/python/training/trackable and tensorflow/python/training/saving.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os.path
import threading
import time

import numpy as np
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import training_util
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export

# TODO(allenl): Remove these aliases once all users are migrated off.
get_checkpoint_state = checkpoint_management.get_checkpoint_state
update_checkpoint_state = checkpoint_management.update_checkpoint_state
generate_checkpoint_state_proto = (
    checkpoint_management.generate_checkpoint_state_proto)
latest_checkpoint = checkpoint_management.latest_checkpoint
checkpoint_exists = checkpoint_management.checkpoint_exists
get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes
remove_checkpoint = checkpoint_management.remove_checkpoint


# Captures the timestamp of the first Saver object instantiation or end of a
# save operation. Can be accessed by multiple Saver instances.
_END_TIME_OF_LAST_WRITE = None
_END_TIME_OF_LAST_WRITE_LOCK = threading.Lock()

# API label for cell name used in checkpoint metrics.
_SAVER_LABEL = "saver_v1"


def _get_duration_microseconds(start_time_seconds, end_time_seconds):
  if end_time_seconds < start_time_seconds:
    # Avoid returning negative value in case of clock skew.
    return 0
  return round((end_time_seconds - start_time_seconds) * 1000000)


class BaseSaverBuilder(object):
  """Base class for Savers.

  Can be extended to create different Ops.
  """

  SaveSpec = saveable_object.SaveSpec
  SaveableObject = saveable_object.SaveableObject

  # Aliases for code which was moved but still has lots of users.
  VariableSaveable = saveable_object_util.ReferenceVariableSaveable
  ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable

  def __init__(self, write_version=saver_pb2.SaverDef.V2):
    self._write_version = write_version

  def save_op(self, filename_tensor, saveables):
    """Create an Op to save 'saveables'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor.
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      An Operation that saves the variables.

    Raises:
      RuntimeError: (implementation detail) if "self._write_version" is an
        unexpected value.
    """
    # pylint: disable=protected-access
    tensor_names = []
    tensors = []
    tensor_slices = []
    for saveable in saveables:
      for spec in saveable.specs:
        tensor_names.append(spec.name)
        tensors.append(spec.tensor)
        tensor_slices.append(spec.slice_spec)
    if self._write_version == saver_pb2.SaverDef.V1:
      return io_ops._save(
          filename=filename_tensor,
          tensor_names=tensor_names,
          tensors=tensors,
          tensor_slices=tensor_slices)
    elif self._write_version == saver_pb2.SaverDef.V2:
      # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
      # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
      return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                            tensors)
    else:
      raise RuntimeError(
          "Unexpected write_version: " + str(self._write_version))

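  # For illustration, a hypothetical subclass (not part of this module) could
  # override save_op, as the docstring above suggests, to log the tensor names
  # being written before delegating to this implementation:
  #
  #   class LoggingSaverBuilder(BaseSaverBuilder):
  #
  #     def save_op(self, filename_tensor, saveables):
  #       for saveable in saveables:
  #         for spec in saveable.specs:
  #           logging.info("Writing %s", spec.name)
  #       return super(LoggingSaverBuilder, self).save_op(
  #           filename_tensor, saveables)
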
  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    """Restore all tensors contained in saveables.

    By default, this issues separate calls to `restore_op` for each saveable.
    Subclasses may override to load multiple saveables in a single call.

    Args:
      filename_tensor: String Tensor.
      saveables: List of BaseSaverBuilder.SaveableObject objects.
      preferred_shard: Int.  Shard to open first when loading a sharded file.
      restore_sequentially: Unused.  Bool.  If true, each restore is sequential.

    Returns:
      A list of Tensors resulting from reading 'saveables' from
        'filename'.

    """
    del restore_sequentially
    all_tensors = []
    for saveable in saveables:
      if saveable.device:
        device = saveable_object_util.set_cpu0(saveable.device)
      else:
        device = None
      with ops.device(device):
        all_tensors.extend(
            self.restore_op(filename_tensor, saveable, preferred_shard))
    return all_tensors

  # pylint: disable=unused-argument
  def restore_op(self, filename_tensor, saveable, preferred_shard):
    """Create ops to restore 'saveable'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor.
      saveable: A BaseSaverBuilder.SaveableObject object.
      preferred_shard: Int.  Shard to open first when loading a sharded file.

    Returns:
      A list of Tensors resulting from reading 'saveable' from
        'filename'.
    """
    # pylint: disable=protected-access
    tensors = []
    for spec in saveable.specs:
      tensors.append(
          io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec],
                            [spec.dtype])[0])

    return tensors

  # pylint: enable=unused-argument

  def sharded_filename(self, filename_tensor, shard, num_shards):
    """Append sharding information to a filename.

    Args:
      filename_tensor: A string tensor.
      shard: Integer.  The shard for the filename.
      num_shards: An int Tensor for the number of shards.

    Returns:
      A string tensor.
    """
    return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)

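  # For illustration: the ShardedFilename op joins the base name with the
  # shard index and shard count, so shard 3 of 8 is expected to yield a name
  # of the form "<filename>-00003-of-00008" (the exact padding is determined
  # by the kernel).
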
  def _AddSaveOps(self, filename_tensor, saveables):
    """Add ops to save variables that are on the same shard.

    Args:
      filename_tensor: String Tensor.
      saveables: A list of SaveableObject objects.

    Returns:
      A tensor with the filename used to save.
    """
    save = self.save_op(filename_tensor, saveables)
    return control_flow_ops.with_dependencies([save], filename_tensor)

  def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device):
    """Add ops to save the params per shard, for the V2 format.

    Note that the sharded save procedure for the V2 format is different from
    V1: there is a special "merge" step that merges the small metadata produced
    from each device.

    Args:
      checkpoint_prefix: scalar String Tensor.  Interpreted *NOT AS A FILENAME*,
        but as a prefix of a V2 checkpoint.
      per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
        returned by _GroupByDevices().

    Returns:
      An op to save the variables, which, when evaluated, returns the prefix
        "<user-fed prefix>" only and does not include the sharded spec suffix.
    """
    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>.
    # * If checkpoint_prefix is an S3 bucket path, ".part" is appended to it.
    # * Otherwise "_temp/part" is appended, normalized relative to the OS.
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    #   Filesystems with eventual consistency (such as S3) don't need a
    #   temporary location. Using a temporary directory in those cases might
    #   cause situations where files are not available during copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
    # prefix directly, instead of any physical pathname.  (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      _SHARDED_SUFFIX = array_ops.where(
          string_ops.regex_full_match(checkpoint_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant(os.path.normpath("_temp/part")))
      tmp_checkpoint_prefix = string_ops.string_join(
          [checkpoint_prefix, _SHARDED_SUFFIX])

    num_shards = len(per_device)
    sharded_saves = []
    sharded_prefixes = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    last_device = None
    for shard, (device, saveables) in enumerate(per_device):
      last_device = device
      with ops.device(saveable_object_util.set_cpu0(device)):
        sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
                                                 num_shards_tensor)
        sharded_prefixes.append(sharded_filename)
        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))

    with ops.control_dependencies([x.op for x in sharded_saves]):
      # Co-locates the merge step with the last device.
      with ops.device(saveable_object_util.set_cpu0(last_device)):
        # V2 format write path consists of a metadata merge step.  Once merged,
        # attempts to delete the temporary directory, "<user-fed prefix>_temp".
        merge_step = gen_io_ops.merge_v2_checkpoints(
            sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
        with ops.control_dependencies([merge_step]):
          # Returns the prefix "<user-fed prefix>" only.  DOES NOT include the
          # sharded spec suffix.
          return array_ops.identity(checkpoint_prefix)

  def _AddShardedSaveOps(self, filename_tensor, per_device):
    """Add ops to save the params per shard.

    Args:
      filename_tensor: a scalar String Tensor.
      per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
        returned by _GroupByDevices().

    Returns:
      An op to save the variables.
    """
    if self._write_version == saver_pb2.SaverDef.V2:
      return self._AddShardedSaveOpsForV2(filename_tensor, per_device)

    num_shards = len(per_device)
    sharded_saves = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    for shard, (device, saveables) in enumerate(per_device):
      with ops.device(device):
        sharded_filename = self.sharded_filename(filename_tensor, shard,
                                                 num_shards_tensor)
        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
    # Return the sharded name for the save path.
    with ops.control_dependencies([x.op for x in sharded_saves]):
      return gen_io_ops.sharded_filespec(filename_tensor, num_shards_tensor)

  def _AddRestoreOps(self,
                     filename_tensor,
                     saveables,
                     restore_sequentially,
                     reshape,
                     preferred_shard=-1,
                     name="restore_all"):
    """Add operations to restore saveables.

    Args:
      filename_tensor: Tensor for the path of the file to load.
      saveables: A list of SaveableObject objects.
      restore_sequentially: True if we want to restore variables sequentially
        within a shard.
      reshape: True if we want to reshape loaded tensors to the shape of the
        corresponding variable.
      preferred_shard: Shard to open first when loading a sharded file.
      name: Name for the returned op.

    Returns:
      An Operation that restores the variables.
    """
    all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
                                    restore_sequentially)

    assign_ops = []
    idx = 0
    # Load and optionally reshape on the CPU, as string tensors are not
    # available on the GPU.
    # TODO(touts): Re-enable restore on GPU when we can support annotating
    # string tensors as "HostMemory" inputs.
    for saveable in saveables:
      shapes = None
      if reshape:
        # Compute the shapes, let the restore op decide if and how to do
        # the reshape.
        shapes = []
        for spec in saveable.specs:
          v = spec.tensor
          shape = v.get_shape()
          if not shape.is_fully_defined():
            shape = array_ops.shape(v)
          shapes.append(shape)
      saveable_tensors = all_tensors[idx:idx + len(saveable.specs)]
      idx += len(saveable.specs)
      assign_ops.append(saveable.restore(saveable_tensors, shapes))

    # Create a Noop that has control dependencies from all the updates.
    return control_flow_ops.group(*assign_ops, name=name)

  def _AddShardedRestoreOps(self, filename_tensor, per_device,
                            restore_sequentially, reshape):
    """Add Ops to restore variables from multiple devices.

    Args:
      filename_tensor: Tensor for the path of the file to load.
      per_device: A list of (device, SaveableObject) pairs, as returned by
        _GroupByDevices().
      restore_sequentially: True if we want to restore variables sequentially
        within a shard.
      reshape: True if we want to reshape loaded tensors to the shape of the
        corresponding variable.

    Returns:
      An Operation that restores the variables.
    """
    sharded_restores = []
    for shard, (device, saveables) in enumerate(per_device):
      with ops.device(device):
        sharded_restores.append(
            self._AddRestoreOps(
                filename_tensor,
                saveables,
                restore_sequentially,
                reshape,
                preferred_shard=shard,
                name="restore_shard"))
    return control_flow_ops.group(*sharded_restores, name="restore_all")

  def _GroupByDevices(self, saveables):
    """Group Variable tensor slices per device.

    TODO(touts): Make sure that all the devices found are on different
    job/replica/task/cpu|gpu.  It would be bad if 2 were on the same device.
    It can happen if the devices are unspecified.

    Args:
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      A list of (device_name, BaseSaverBuilder.SaveableObject) tuples.
      The list is sorted by ascending device_name.

    Raises:
      ValueError: If the tensors of a saveable are on different devices.
    """
    per_device = collections.defaultdict(lambda: [])
    for saveable in saveables:
      canonical_device = set(
          pydev.canonical_name(spec.device) for spec in saveable.specs)
      if len(canonical_device) != 1:
        raise ValueError("All tensors of a saveable object must be "
                         "on the same device: %s" % saveable.name)
      per_device[canonical_device.pop()].append(saveable)
    return sorted(per_device.items(), key=lambda t: t[0])

  def build(self,
            names_to_saveables,
            reshape=False,
            sharded=False,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=10000.0,
            name=None,
            restore_sequentially=False,
            filename="model"):
    """Builds save/restore graph nodes or runs save/restore in eager mode.

    Args:
      names_to_saveables: A dictionary mapping name to a Variable or
        SaveableObject. Each name will be associated with the corresponding
        variable in the checkpoint.
      reshape: If True, allow restoring parameters from a checkpoint where the
        parameters have a different shape.  This is only needed when you try
        to restore from a Dist-Belief checkpoint, and only sometimes.
      sharded: If True, shard the checkpoints, one per device that has Variable
        nodes.
      max_to_keep: Maximum number of checkpoints to keep.  As new checkpoints
        are created, old ones are deleted.  If None or 0, no checkpoints are
        deleted from the filesystem but only the last one is kept in the
        `checkpoint` file.  Presently the number is only roughly enforced.  For
        example, in case of restarts, more than max_to_keep checkpoints may be
        kept.
      keep_checkpoint_every_n_hours: How often checkpoints should be kept.
        Defaults to 10,000 hours.
      name: String.  Optional name to use as a prefix when adding operations.
      restore_sequentially: A Bool, which if true, causes restore of different
        variables to happen sequentially within each device.
      filename: If known at graph construction time, filename used for variable
        loading/saving. If None, then the default name "model" will be used.

    Returns:
      A SaverDef proto.

    Raises:
      TypeError: If 'names_to_saveables' is not a dictionary mapping string
        keys to variable Tensors.
      ValueError: If any of the keys or values in 'names_to_saveables' is not
        unique.
    """
    return self._build_internal(
        names_to_saveables=names_to_saveables,
        reshape=reshape,
        sharded=sharded,
        max_to_keep=max_to_keep,
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
        name=name,
        restore_sequentially=restore_sequentially,
        filename=filename)

  def _build_internal(self,
                      names_to_saveables,
                      reshape=False,
                      sharded=False,
                      max_to_keep=5,
                      keep_checkpoint_every_n_hours=10000.0,
                      name=None,
                      restore_sequentially=False,
                      filename="model",
                      build_save=True,
                      build_restore=True):
    """build() with option to only perform save and restore."""
    if not context.executing_eagerly() and (not build_save or
                                            not build_restore):
      raise ValueError("save and restore operations need to be built together "
                       "when eager execution is not enabled.")

    saveables = saveable_object_util.validate_and_slice_inputs(
        names_to_saveables)
    if max_to_keep is None:
      max_to_keep = 0

    with ops.name_scope(name, "save",
                        [saveable.op for saveable in saveables]) as name:
      # Add a placeholder string tensor for the filename.
      filename_tensor = array_ops.placeholder_with_default(
          filename or "model", shape=(), name="filename")
      # Keep the name "Const" for backwards compatibility.
      filename_tensor = array_ops.placeholder_with_default(
          filename_tensor, shape=(), name="Const")

      # Add the save ops.
      if sharded:
        per_device = self._GroupByDevices(saveables)
        if build_save:
          save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
        if build_restore:
          restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
                                                  restore_sequentially, reshape)
      else:
        if build_save:
          save_tensor = self._AddSaveOps(filename_tensor, saveables)
        if build_restore:
          restore_op = self._AddRestoreOps(filename_tensor, saveables,
                                           restore_sequentially, reshape)

    # In the following use case, it's possible to have restore_ops be called
    # something else:
    # - Build inference graph and export a meta_graph.
    # - Import the inference meta_graph
    # - Extend the inference graph to a train graph.
    # - Export a new meta_graph.
    # Now the second restore_op will be called "restore_all_1".
    # As such, comment out the assert for now until we know whether supporting
    # such usage model makes sense.
    #
    # assert restore_op.name.endswith("restore_all"), restore_op.name
    if context.executing_eagerly():
      # Store the tensor values to the tensor_names.
      save_tensor_name = save_tensor.numpy() if build_save else ""
      return saver_pb2.SaverDef(
          filename_tensor_name=filename_tensor.numpy(),
          save_tensor_name=save_tensor_name,
          restore_op_name="",
          max_to_keep=max_to_keep,
          sharded=sharded,
          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
          version=self._write_version)
    else:
      graph = ops.get_default_graph()
      # Do some sanity checking on collections containing
      # PartitionedVariables. If a saved collection has a PartitionedVariable,
      # the GraphDef needs to include concat ops to get the value (or there'll
      # be a lookup error on load).
      check_collection_list = graph.get_all_collection_keys()
      for collection_type in check_collection_list:
        for element in graph.get_collection(collection_type):
          if isinstance(element, variables.PartitionedVariable):
            try:
              graph.get_operation_by_name(element.name)
            except KeyError:
              # Create a concat op for this PartitionedVariable. The user may
              # not need it, but we'll try looking it up on MetaGraph restore
              # since it's in a collection.
              element.as_tensor()
      return saver_pb2.SaverDef(
          filename_tensor_name=filename_tensor.name,
          save_tensor_name=save_tensor.name,
          restore_op_name=restore_op.name,
          max_to_keep=max_to_keep,
          sharded=sharded,
          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
          version=self._write_version)


class BulkSaverBuilder(BaseSaverBuilder):
  """SaverBuilder with support for bulk restoring multiple saveables."""

  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):

    # Ignored: bulk restore is internally sequential.
    del restore_sequentially
    restore_specs = []
    for saveable in saveables:
      for spec in saveable.specs:
        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))

    names, slices, dtypes = zip(*restore_specs)
    # Load all tensors onto CPU 0 for compatibility with existing code.
    with ops.device("cpu:0"):
      return io_ops.restore_v2(filename_tensor, names, slices, dtypes)


def _get_saver_or_default():
  """Returns the saver from SAVERS collection, or creates a default one.

  This method is used by other members of the training module, such as
  `Scaffold`, or `CheckpointSaverHook`.

  Returns:
    `Saver`.

  Raises:
    RuntimeError: If the SAVERS collection already has more than one item.
  """
  collection_key = ops.GraphKeys.SAVERS
  savers = ops.get_collection(collection_key)
  if savers:
    if len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))
    return savers[0]
  saver = Saver(sharded=True, allow_empty=True)
  if saver is not None:
    ops.add_to_collection(collection_key, saver)
  return saver


@tf_export(v1=["train.Saver"])
class Saver(object):
  # pylint: disable=line-too-long
  """Saves and restores variables.

  @compatibility(TF2)
  `tf.compat.v1.train.Saver` is not supported for saving and restoring
  checkpoints in TF2. Please switch to `tf.train.Checkpoint` or
  `tf.keras.Model.save_weights`, which perform a more robust [object-based
  saving](https://www.tensorflow.org/guide/checkpoint#loading_mechanics).

  ### How to Rewrite Checkpoints

  Please rewrite your checkpoints immediately using the object-based checkpoint
  APIs.

  You can load a name-based checkpoint written by `tf.compat.v1.train.Saver`
  using `tf.train.Checkpoint.restore` or `tf.keras.Model.load_weights`. However,
  you may have to change the names of the variables in your model to match the
  variable names in the name-based checkpoint, which can be viewed with
  `tf.train.list_variables(path)`.

  Another option is to create an `assignment_map` that maps the name of the
  variables in the name-based checkpoint to the variables in your model, e.g.:
  ```
  {
      'sequential/dense/bias': model.variables[0],
      'sequential/dense/kernel': model.variables[1]
  }
  ```
  and use `tf.compat.v1.train.init_from_checkpoint(path, assignment_map)` to
  restore the name-based checkpoint.

  After restoring, re-encode your checkpoint
  using `tf.train.Checkpoint.save` or `tf.keras.Model.save_weights`.

  See the [Checkpoint compatibility](
  https://www.tensorflow.org/guide/migrate#checkpoint_compatibility)
  section of the migration guide for more details.


  ### Checkpoint Management in TF2

  Use `tf.train.CheckpointManager` to manage checkpoints in TF2.
  `tf.train.CheckpointManager` offers equivalent `keep_checkpoint_every_n_hours`
  and `max_to_keep` parameters.

  To recover the latest checkpoint,

  ```
  checkpoint = tf.train.Checkpoint(model)
  manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir,
                                       max_to_keep=5)
  status = checkpoint.restore(manager.latest_checkpoint)
  ```

  `tf.train.CheckpointManager` also writes a [`CheckpointState` proto]
  (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/checkpoint_state.proto)
  which contains the timestamp when each checkpoint was created.

  ### Writing `MetaGraphDef`s in TF2

  To replace `tf.compat.v1.train.Saver.save(write_meta_graph=True)`, use
  `tf.saved_model.save` to write the `MetaGraphDef` (which is contained in
  `saved_model.pb`).

  @end_compatibility

  See [Variables](https://tensorflow.org/guide/variables)
  for an overview of variables, saving and restoring.

  The `Saver` class adds ops to save and restore variables to and from
  *checkpoints*.  It also provides convenience methods to run these ops.

  Checkpoints are binary files in a proprietary format which map variable names
  to tensor values.  The best way to examine the contents of a checkpoint is to
  load it using a `Saver`.

  Savers can automatically number checkpoint filenames with a provided counter.
  This lets you keep multiple checkpoints at different steps while training a
  model.  For example you can number the checkpoint filenames with the training
  step number.  To avoid filling up disks, savers manage checkpoint files
  automatically. For example, they can keep only the N most recent files, or
  one checkpoint for every N hours of training.

  You number checkpoint filenames by passing a value to the optional
  `global_step` argument to `save()`:

  ```python
  saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
  ...
  saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
  ```

  Additionally, optional arguments to the `Saver()` constructor let you control
  the proliferation of checkpoint files on disk:

  * `max_to_keep` indicates the maximum number of recent checkpoint files to
    keep.  As new files are created, older files are deleted.  If None or 0,
    no checkpoints are deleted from the filesystem but only the last one is
    kept in the `checkpoint` file.  Defaults to 5 (that is, the 5 most recent
    checkpoint files are kept).

  * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent
    `max_to_keep` checkpoint files, you might want to keep one checkpoint file
    for every N hours of training.  This can be useful if you want to later
    analyze how a model progressed during a long training session.  For
    example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep
    one checkpoint file for every 2 hours of training.  The default value of
    10,000 hours effectively disables the feature.

  Note that you still have to call the `save()` method to save the model.
  Passing these arguments to the constructor will not save variables
  automatically for you.

  A training program that saves regularly looks like:

  ```python
  ...
  # Create a saver.
  saver = tf.compat.v1.train.Saver(...variables...)
  # Launch the graph and train, saving the model every 1,000 steps.
  sess = tf.compat.v1.Session()
  for step in range(1000000):
      sess.run(..training_op..)
      if step % 1000 == 0:
          # Append the step number to the checkpoint name:
          saver.save(sess, 'my-model', global_step=step)
  ```

  In addition to checkpoint files, savers keep a protocol buffer on disk with
  the list of recent checkpoints. This is used to manage numbered checkpoint
  files and by `latest_checkpoint()`, which makes it easy to discover the path
  to the most recent checkpoint. That protocol buffer is stored in a file named
  'checkpoint' next to the checkpoint files.

  If you create several savers, you can specify a different filename for the
  protocol buffer file in the call to `save()`.
  """
  # pylint: enable=line-too-long

  def __init__(self,
               var_list=None,
               reshape=False,
               sharded=False,
               max_to_keep=5,
               keep_checkpoint_every_n_hours=10000.0,
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False,
               filename=None):
    """Creates a `Saver`.

    The constructor adds ops to save and restore variables.

    `var_list` specifies the variables that will be saved and restored. It can
    be passed as a `dict` or a list:

    * A `dict` of names to variables: The keys are the names that will be
      used to save or restore the variables in the checkpoint files.
    * A list of variables: The variables will be keyed with their op name in
      the checkpoint files.

    For example:

    ```python
    v1 = tf.Variable(..., name='v1')
    v2 = tf.Variable(..., name='v2')

    # Pass the variables as a dict:
    saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2})

    # Or pass them as a list.
    saver = tf.compat.v1.train.Saver([v1, v2])
    # Passing a list is equivalent to passing a dict with the variable op names
    # as keys:
    saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]})
    ```

    Note: the newer `AutoTrackable` API is not supported by `Saver`. In this
    case, the `tf.train.Checkpoint` class should be used.

    The optional `reshape` argument, if `True`, allows restoring a variable from
    a save file where the variable had a different shape, but the same number
    of elements and type.  This is useful if you have reshaped a variable and
    want to reload it from an older checkpoint.

    The optional `sharded` argument, if `True`, instructs the saver to shard
    checkpoints per device.

    Args:
      var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
        names to `SaveableObject`s. If `None`, defaults to the list of all
        saveable objects.
      reshape: If `True`, allows restoring parameters from a checkpoint where
        the variables have a different shape.
      sharded: If `True`, shard the checkpoints, one per device.
      max_to_keep: Maximum number of recent checkpoints to keep. Defaults to 5.
      keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to
        10,000 hours.
      name: String.  Optional name to use as a prefix when adding operations.
      restore_sequentially: A `Bool`, which if true, causes restore of different
        variables to happen sequentially within each device.  This can lower
        memory usage when restoring very large models.
      saver_def: Optional `SaverDef` proto to use instead of running the
        builder. This is only useful for specialty code that wants to recreate a
        `Saver` object for a previously built `Graph` that had a `Saver`. The
        `saver_def` proto should be the one returned by the `as_saver_def()`
        call of the `Saver` that was created for that `Graph`.
      builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
        Defaults to `BulkSaverBuilder()`.
      defer_build: If `True`, defer adding the save and restore ops to the
        `build()` call. In that case `build()` should be called before
        finalizing the graph or using the saver.
      allow_empty: If `False` (default) raise an error if there are no variables
        in the graph. Otherwise, construct the saver anyway and make it a no-op.
      write_version: Controls what format to use when saving checkpoints.  It
        also affects certain filepath matching logic.  The V2 format is the
        recommended choice: it is much more optimized than V1 in terms of
        memory required and latency incurred during restore.  Regardless of
        this flag, the Saver is able to restore from both V2 and V1 checkpoints.
      pad_step_number: If `True`, pads the global step number in the checkpoint
        filepaths to some fixed width (8 by default).  This is turned off by
        default.
      save_relative_paths: If `True`, will write relative paths to the
        checkpoint state file. This is needed if the user wants to copy the
        checkpoint directory and reload from the copied directory.
      filename: If known at graph construction time, filename used for variable
        loading/saving.

    Raises:
      TypeError: If `var_list` is invalid.
      ValueError: If any of the keys or values in `var_list` are not unique.
      RuntimeError: If eager execution is enabled and `var_list` does not
        specify a list of variables to save.

    @compatibility(eager)
    When eager execution is enabled, `var_list` must specify a `list` or `dict`
    of variables to save. Otherwise, a `RuntimeError` will be raised.

    Although Saver works in some cases when executing eagerly, it is
    fragile. Please switch to `tf.train.Checkpoint` or
    `tf.keras.Model.save_weights`, which perform a more robust object-based
    saving. These APIs will load checkpoints written by `Saver`.
    @end_compatibility
    """
    global _END_TIME_OF_LAST_WRITE
    with _END_TIME_OF_LAST_WRITE_LOCK:
      if _END_TIME_OF_LAST_WRITE is None:
        _END_TIME_OF_LAST_WRITE = time.time()

    if defer_build and var_list:
      raise ValueError(
          "If `var_list` is provided then build cannot be deferred. "
          "Either set defer_build=False or var_list=None.")
    if context.executing_eagerly():
      logging.warning(
          "Saver is deprecated, please switch to tf.train.Checkpoint or "
          "tf.keras.Model.save_weights for training checkpoints. When "
          "executing eagerly variables do not necessarily have unique names, "
          "and so the variable.name-based lookups Saver performs are "
          "error-prone.")
      if var_list is None:
        raise RuntimeError(
            "When eager execution is enabled, `var_list` must specify a list "
            "or dict of variables to save")
    self._var_list = var_list
    self._reshape = reshape
    self._sharded = sharded
    self._max_to_keep = max_to_keep
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    self._name = name
    self._restore_sequentially = restore_sequentially
    self.saver_def = saver_def
    self._builder = builder
    self._is_built = False
    self._allow_empty = allow_empty
    self._is_empty = None
    self._write_version = write_version
    self._pad_step_number = pad_step_number
    self._filename = filename
    self._last_checkpoints = []
    self._checkpoints_to_be_deleted = []
    if context.executing_eagerly():
      self._next_checkpoint_time = (
          time.time() + self._keep_checkpoint_every_n_hours * 3600)
    elif not defer_build:
      self.build()
    if self.saver_def:
      self._check_saver_def()
      self._write_version = self.saver_def.version
    self._save_relative_paths = save_relative_paths
    # For compatibility with object-based checkpoints, we may build a second
    # Saver to read the renamed keys.
    self._object_restore_saver = None

  def build(self):
    if context.executing_eagerly():
      raise RuntimeError("Use save/restore instead of build in eager mode.")
    self._build(self._filename, build_save=True, build_restore=True)

  def _build_eager(self, checkpoint_path, build_save, build_restore):
    self._build(
        checkpoint_path, build_save=build_save, build_restore=build_restore)

  def _build(self, checkpoint_path, build_save, build_restore):
    """Builds saver_def."""
    if not context.executing_eagerly():
      if self._is_built:
        return
      self._is_built = True

    if not self.saver_def or context.executing_eagerly():
      if self._builder is None:
        self._builder = BulkSaverBuilder(self._write_version)

      if self._var_list is None:
        # pylint: disable=protected-access
        self._var_list = variables._all_saveable_objects()
      if not self._var_list:
        if self._allow_empty:
          self._is_empty = True
          return
        else:
          raise ValueError("No variables to save")
      self._is_empty = False

      self.saver_def = self._builder._build_internal(  # pylint: disable=protected-access
          self._var_list,
          reshape=self._reshape,
          sharded=self._sharded,
          max_to_keep=self._max_to_keep,
          keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
          name=self._name,
          restore_sequentially=self._restore_sequentially,
          filename=checkpoint_path,
          build_save=build_save,
          build_restore=build_restore)
    elif self.saver_def and self._name:
      # Since self._name is used as a name_scope by builder(), we are
      # overloading the use of this field to represent the "import_scope" as
      # well.
      self.saver_def.filename_tensor_name = ops.prepend_name_scope(
          self.saver_def.filename_tensor_name, self._name)
      self.saver_def.save_tensor_name = ops.prepend_name_scope(
          self.saver_def.save_tensor_name, self._name)
      self.saver_def.restore_op_name = ops.prepend_name_scope(
          self.saver_def.restore_op_name, self._name)

    self._check_saver_def()
    if not context.executing_eagerly():
      # Updates next checkpoint time.
      # Set in __init__ when executing eagerly.
      self._next_checkpoint_time = (
          time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)

  def _check_saver_def(self):
    if not isinstance(self.saver_def, saver_pb2.SaverDef):
      raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
                       self.saver_def)
    if not context.executing_eagerly():
      if not self.saver_def.save_tensor_name:
        raise ValueError("saver_def must specify the save_tensor_name: %s" %
                         str(self.saver_def))
      if not self.saver_def.restore_op_name:
        raise ValueError("saver_def must specify the restore_op_name: %s" %
                         str(self.saver_def))

  def _CheckpointFilename(self, p):
    """Returns the checkpoint filename given a `(filename, time)` pair.

    Args:
      p: (filename, time) pair.

    Returns:
      Checkpoint file name.
    """
    name, _ = p
    return name

  def _RecordLastCheckpoint(self, latest_save_path):
    """Manages the list of the latest checkpoints."""
    if not self.saver_def.max_to_keep:
      return
    # Remove first from list if the same name was used before.
    for p in self._last_checkpoints:
      if latest_save_path == self._CheckpointFilename(p):
        self._last_checkpoints.remove(p)
    # Append new path to list.
    self._last_checkpoints.append((latest_save_path, time.time()))

    # If more than max_to_keep, remove oldest.
    if len(self._last_checkpoints) > self.saver_def.max_to_keep:
      self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0))

  def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"):
    """Deletes old checkpoints if necessary.

    `self._checkpoints_to_be_deleted` is going to contain checkpoints that are
    over `max_to_keep`.  They are going to be deleted.  If
    `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
    every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
    kept for every 0.5 hours of training; if `N` is 10, an additional
    checkpoint is kept for every 10 hours of training.

    Args:
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
    """
    if self._checkpoints_to_be_deleted:
      p = self._checkpoints_to_be_deleted.pop(0)
      # Do not delete the file if keep_checkpoint_every_n_hours is set and we
      # have reached N hours of training.
      should_keep = p[1] > self._next_checkpoint_time
      if should_keep:
        self._next_checkpoint_time += (
            self.saver_def.keep_checkpoint_every_n_hours * 3600)
        return

      # Otherwise delete the files.
      try:
        checkpoint_management.remove_checkpoint(
            self._CheckpointFilename(p), self.saver_def.version,
            meta_graph_suffix)
      except Exception as e:  # pylint: disable=broad-except
        logging.warning("Ignoring: %s", str(e))

  def as_saver_def(self):
    """Generates a `SaverDef` representation of this saver.

    Returns:
      A `SaverDef` proto.
    """
    return self.saver_def

  def to_proto(self, export_scope=None):
    """Converts this `Saver` to a `SaverDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `SaverDef` protocol buffer.
    """
    if export_scope is None:
      return self.saver_def

    if not (self.saver_def.filename_tensor_name.startswith(export_scope) and
            self.saver_def.save_tensor_name.startswith(export_scope) and
            self.saver_def.restore_op_name.startswith(export_scope)):
      return None

    saver_def = saver_pb2.SaverDef()
    saver_def.CopyFrom(self.saver_def)
    saver_def.filename_tensor_name = ops.strip_name_scope(
        saver_def.filename_tensor_name, export_scope)
    saver_def.save_tensor_name = ops.strip_name_scope(
        saver_def.save_tensor_name, export_scope)
    saver_def.restore_op_name = ops.strip_name_scope(saver_def.restore_op_name,
                                                     export_scope)
    return saver_def

  @staticmethod
  def from_proto(saver_def, import_scope=None):
    """Returns a `Saver` object created from `saver_def`.

    Args:
      saver_def: a `SaverDef` protocol buffer.
      import_scope: Optional `string`. Name scope to use.

    Returns:
      A `Saver` built from saver_def.
    """
    return Saver(saver_def=saver_def, name=import_scope)

  @property
  def last_checkpoints(self):
    """List of not-yet-deleted checkpoint filenames.

    You can pass any of the returned values to `restore()`.

    Returns:
      A list of checkpoint filenames, sorted from oldest to newest.
    """
    return list(self._CheckpointFilename(p) for p in self._last_checkpoints)

  def set_last_checkpoints(self, last_checkpoints):
    """DEPRECATED: Use set_last_checkpoints_with_time.

    Sets the list of old checkpoint filenames.

    Args:
      last_checkpoints: A list of checkpoint filenames.

    Raises:
      AssertionError: If last_checkpoints is not a list.
    """
    assert isinstance(last_checkpoints, list)
    # We use a timestamp of +inf so that this checkpoint will never be
    # deleted.  This is both safe and backwards compatible with a previous
    # version of the code which used s[1] as the "timestamp".
    self._last_checkpoints = [(s, np.inf) for s in last_checkpoints]

  def set_last_checkpoints_with_time(self, last_checkpoints_with_time):
    """Sets the list of old checkpoint filenames and timestamps.

    Args:
      last_checkpoints_with_time: A list of tuples of checkpoint filenames and
        timestamps.

    Raises:
      AssertionError: If last_checkpoints_with_time is not a list.
    """
    assert isinstance(last_checkpoints_with_time, list)
    self._last_checkpoints = last_checkpoints_with_time

  def recover_last_checkpoints(self, checkpoint_paths):
    """Recovers the internal saver state after a crash.

    This method is useful for recovering the "self._last_checkpoints" state.

    Globs for the checkpoints pointed to by `checkpoint_paths`.  If the files
    exist, use their mtime as the checkpoint timestamp.

    Args:
      checkpoint_paths: a list of checkpoint paths.
    """
    checkpoints_with_mtimes = []
    for checkpoint_path in checkpoint_paths:
      try:
        mtime = checkpoint_management.get_checkpoint_mtimes([checkpoint_path])
      except errors.NotFoundError:
        # It's fine if some other thread/process is deleting some older
        # checkpoint concurrently.
        continue
      if mtime:
        checkpoints_with_mtimes.append((checkpoint_path, mtime[0]))
    self.set_last_checkpoints_with_time(checkpoints_with_mtimes)

  def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True,
           strip_default_attrs=False,
           save_debug_info=False):
    # pylint: disable=line-too-long
    """Saves variables.

    This method runs the ops added by the constructor for saving variables.
    It requires a session in which the graph was launched.  The variables to
    save must also have been initialized.

    The method returns the path prefix of the newly created checkpoint files.
    This string can be passed directly to a call to `restore()`.

    Args:
      sess: A Session to use to save the variables.
      save_path: String.  Prefix of filenames created for the checkpoint.
      global_step: If provided, the global step number is appended to
        `save_path` to create the checkpoint filenames. The optional argument
        can be a `Tensor`, a `Tensor` name or an integer.
      latest_filename: Optional name for the protocol buffer file that will
        contain the list of most recent checkpoints.  That file, kept in the
        same directory as the checkpoint files, is automatically managed by the
        saver to keep track of recent checkpoints.  Defaults to 'checkpoint'.
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
      write_meta_graph: `Boolean` indicating whether or not to write the meta
        graph file.
      write_state: `Boolean` indicating whether or not to write the
        `CheckpointStateProto`.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued
          Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
        which is in the same directory as save_path and with `_debug` added
        before the file extension. This is only enabled when `write_meta_graph`
        is `True`.

    Returns:
      A string: path prefix used for the checkpoint files.  If the saver is
        sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
        is the number of shards created.
      If the saver is empty, returns None.

    Raises:
      TypeError: If `sess` is not a `Session`.
      ValueError: If `latest_filename` contains path components, or if it
        collides with `save_path`.
      RuntimeError: If save and restore ops weren't built.
    """
    # pylint: enable=line-too-long
    start_time = time.time()
    if not self._is_built and not context.executing_eagerly():
      raise RuntimeError(
          "`build()` should be called before save if defer_build==True")
    if latest_filename is None:
      latest_filename = "checkpoint"
    if self._write_version != saver_pb2.SaverDef.V2:
      logging.warning("*******************************************************")
      logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
      logging.warning("Consider switching to the more efficient V2 format:")
      logging.warning("   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
      logging.warning("now on by default.")
      logging.warning("*******************************************************")

    if os.path.split(latest_filename)[0]:
      raise ValueError("'latest_filename' must not contain path components")

    save_path = compat.as_str(save_path)
    if global_step is not None:
      if not isinstance(global_step, compat.integral_types):
        global_step = training_util.global_step(sess, global_step)
      checkpoint_file = "%s-%d" % (save_path, global_step)
      if self._pad_step_number:
        # Zero-pads the step numbers, so that they are sorted when listed.
        checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
    else:
      checkpoint_file = save_path
      if os.path.basename(save_path) == latest_filename and not self._sharded:
        # Guard against collision between data file and checkpoint state file.
        raise ValueError(
            "'latest_filename' collides with 'save_path': '%s' and '%s'" %
            (latest_filename, save_path))

    if (not context.executing_eagerly() and
        not isinstance(sess, session.SessionInterface)):
      raise TypeError("'sess' must be a Session; %s" % sess)

    save_path_parent = os.path.dirname(save_path)
    if not self._is_empty:
      try:
        if context.executing_eagerly():
          self._build_eager(
              checkpoint_file, build_save=True, build_restore=False)
          model_checkpoint_path = self.saver_def.save_tensor_name
        else:
          model_checkpoint_path = sess.run(
              self.saver_def.save_tensor_name,
              {self.saver_def.filename_tensor_name: checkpoint_file})

        model_checkpoint_path = compat.as_str(model_checkpoint_path)
        if write_state:
          self._RecordLastCheckpoint(model_checkpoint_path)
          checkpoint_management.update_checkpoint_state_internal(
              save_dir=save_path_parent,
              model_checkpoint_path=model_checkpoint_path,
              all_model_checkpoint_paths=self.last_checkpoints,
              latest_filename=latest_filename,
              save_relative_paths=self._save_relative_paths)
          self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
      except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
        if not gfile.IsDirectory(save_path_parent):
          exc = ValueError(
              "Parent directory of {} doesn't exist, can't save.".format(
                  save_path))
        raise exc

    end_time = time.time()
    metrics.AddCheckpointWriteDuration(
        api_label=_SAVER_LABEL,
        microseconds=_get_duration_microseconds(start_time, end_time))
    global _END_TIME_OF_LAST_WRITE
    with _END_TIME_OF_LAST_WRITE_LOCK:
      metrics.AddTrainingTimeSaved(
          api_label=_SAVER_LABEL,
          microseconds=_get_duration_microseconds(_END_TIME_OF_LAST_WRITE,
                                                  end_time))
      _END_TIME_OF_LAST_WRITE = end_time

    if write_meta_graph:
      meta_graph_filename = checkpoint_management.meta_graph_filename(
          checkpoint_file, meta_graph_suffix=meta_graph_suffix)
      if not context.executing_eagerly():
        with sess.graph.as_default():
          self.export_meta_graph(
              meta_graph_filename,
              strip_default_attrs=strip_default_attrs,
              save_debug_info=save_debug_info)

    if self._is_empty:
      return None
    else:
      return model_checkpoint_path

1324  def export_meta_graph(self,
1325                        filename=None,
1326                        collection_list=None,
1327                        as_text=False,
1328                        export_scope=None,
1329                        clear_devices=False,
1330                        clear_extraneous_savers=False,
1331                        strip_default_attrs=False,
1332                        save_debug_info=False):
1333    # pylint: disable=line-too-long
1334    """Writes `MetaGraphDef` to save_path/filename.
1335
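    A minimal usage sketch (the filename below is an illustrative assumption,
    not a required path):

    ```Python
    # Assumes variables have already been added to the default graph.
    saver = tf.compat.v1.train.Saver()
    # Writes the current default graph's MetaGraphDef to the given filename.
    saver.export_meta_graph("/tmp/my-model.meta")
    ```
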
1336    Args:
1337      filename: Optional meta_graph filename including the path.
1338      collection_list: List of string keys to collect.
1339      as_text: If `True`, writes the meta_graph as an ASCII proto.
1340      export_scope: Optional `string`. Name scope to remove.
1341      clear_devices: Whether or not to clear the device field for an `Operation`
1342        or `Tensor` during export.
1343      clear_extraneous_savers: Remove any Saver-related information from the
1344        graph (both Save/Restore ops and SaverDefs) that is not associated with
1345        this Saver.
1346      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
1347        removed from the NodeDefs. For a detailed guide, see
1348        [Stripping Default-Valued
1349          Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
1350      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
1351        which is in the same directory as `filename` and has `_debug` added
1352        before the file extension.
1353
1354    Returns:
1355      A `MetaGraphDef` proto.
1356    """
1357    # pylint: enable=line-too-long
1358    return export_meta_graph(
1359        filename=filename,
1360        graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
1361        saver_def=self.saver_def,
1362        collection_list=collection_list,
1363        as_text=as_text,
1364        export_scope=export_scope,
1365        clear_devices=clear_devices,
1366        clear_extraneous_savers=clear_extraneous_savers,
1367        strip_default_attrs=strip_default_attrs,
1368        save_debug_info=save_debug_info)
1369
1370  def restore(self, sess, save_path):
1371    """Restores previously saved variables.
1372
1373    This method runs the ops added by the constructor for restoring variables.
1374    It requires a session in which the graph was launched.  The variables to
1375    restore do not have to have been initialized, as restoring is itself a way
1376    to initialize variables.
1377
1378    The `save_path` argument is typically a value previously returned from a
1379    `save()` call, or a call to `latest_checkpoint()`.
1380
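    A minimal restore sketch (the checkpoint prefix below is an illustrative
    assumption, not part of this API):

    ```Python
    # Assumes the graph containing the variables to restore is already built.
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
      # Restoring also initializes the variables, so no initializer is needed.
      saver.restore(sess, "/tmp/my-model-1000")
    ```
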
1381    Args:
1382      sess: A `Session` to use to restore the parameters. None in eager mode.
1383      save_path: Path where parameters were previously saved.
1384
1385    Raises:
1386      ValueError: If save_path is None or not a valid checkpoint.
1387    """
1388    start_time = time.time()
1389    if self._is_empty:
1390      return
1391    if save_path is None:
1392      raise ValueError("Can't load save_path when it is None.")
1393
1394    checkpoint_prefix = compat.as_text(save_path)
1395    if not checkpoint_management.checkpoint_exists_internal(checkpoint_prefix):
1396      raise ValueError("The passed save_path is not a valid checkpoint: " +
1397                       checkpoint_prefix)
1398
1399    logging.info("Restoring parameters from %s", checkpoint_prefix)
1400    try:
1401      if context.executing_eagerly():
1402        self._build_eager(save_path, build_save=False, build_restore=True)
1403      else:
1404        sess.run(self.saver_def.restore_op_name,
1405                 {self.saver_def.filename_tensor_name: save_path})
1406    except errors.NotFoundError as err:
1407      # There are three common conditions that might cause this error:
1408      # 0. The file is missing. We ignore this case here, since it is checked above.
1409      # 1. This is an object-based checkpoint trying name-based loading.
1410      # 2. The graph has been altered and a variable or other name is missing.
1411
1412      # 1. The checkpoint would not be loaded successfully as is. Try to parse
1413      # it as an object-based checkpoint.
1414      try:
1415        names_to_keys = object_graph_key_mapping(save_path)
1416      except errors.NotFoundError:
1417        # 2. This is not an object-based checkpoint, which likely means there
1418        # is a graph mismatch. Re-raise the original error with
1419        # a helpful message (b/110263146)
1420        raise _wrap_restore_error_with_msg(
1421            err, "a Variable name or other graph key that is missing")
1422
1423      # This is an object-based checkpoint. We'll print a warning and then do
1424      # the restore.
1425      logging.warning(
1426          "Restoring an object-based checkpoint using a name-based saver. This "
1427          "may be somewhat fragile, and will re-build the Saver. Instead, "
1428          "consider loading object-based checkpoints using "
1429          "tf.train.Checkpoint().")
1430      self._object_restore_saver = saver_from_object_based_checkpoint(
1431          checkpoint_path=save_path,
1432          var_list=self._var_list,
1433          builder=self._builder,
1434          names_to_keys=names_to_keys,
1435          cached_saver=self._object_restore_saver)
1436      self._object_restore_saver.restore(sess=sess, save_path=save_path)
1437    except errors.InvalidArgumentError as err:
1438      # There is a mismatch between the graph and the checkpoint being loaded.
1439      # We add a more reasonable error message here to help users (b/110263146)
1440      raise _wrap_restore_error_with_msg(
1441          err, "a mismatch between the current graph and the graph")
1442    metrics.AddCheckpointReadDuration(
1443        api_label=_SAVER_LABEL,
1444        microseconds=_get_duration_microseconds(start_time, time.time()))
1445
1446  @staticmethod
1447  def _add_collection_def(meta_graph_def, key, export_scope=None):
1448    """Adds a collection to MetaGraphDef protocol buffer.
1449
1450    Args:
1451      meta_graph_def: MetaGraphDef protocol buffer.
1452      key: One of the GraphKeys or user-defined string.
1453      export_scope: Optional `string`. Name scope to remove.
1454    """
1455    meta_graph.add_collection_def(
1456        meta_graph_def, key, export_scope=export_scope)
1457
1458
1459@tf_export(v1=["train.import_meta_graph"])
1460def import_meta_graph(meta_graph_or_file,
1461                      clear_devices=False,
1462                      import_scope=None,
1463                      **kwargs):
1464  """Recreates a Graph saved in a `MetaGraphDef` proto.
1465
1466  This function takes a `MetaGraphDef` protocol buffer as input. If
1467  the argument is a file containing a `MetaGraphDef` protocol buffer,
1468  it constructs a protocol buffer from the file content. The function
1469  then adds all the nodes from the `graph_def` field to the
1470  current graph, recreates all the collections, and returns a saver
1471  constructed from the `saver_def` field.
1472
1473  In combination with `export_meta_graph()`, this function can be used to
1474
1475  * Serialize a graph along with other Python objects such as `QueueRunner`
1476    and `Variable` into a `MetaGraphDef`.
1477
1478  * Restart training from a saved graph and checkpoints.
1479
1480  * Run inference from a saved graph and checkpoints.
1481
1482  ```Python
1483  ...
1484  # Create a saver.
1485  saver = tf.compat.v1.train.Saver(...variables...)
1486  # Remember the training_op we want to run by adding it to a collection.
1487  tf.compat.v1.add_to_collection('train_op', train_op)
1488  sess = tf.compat.v1.Session()
1489  for step in range(1000000):
1490      sess.run(train_op)
1491      if step % 1000 == 0:
1492          # Saves checkpoint, which by default also exports a meta_graph
1493          # named 'my-model-global_step.meta'.
1494          saver.save(sess, 'my-model', global_step=step)
1495  ```
1496
1497  Later we can continue training from this saved `meta_graph` without building
1498  the model from scratch.
1499
1500  ```Python
1501  with tf.Session() as sess:
1502    new_saver = tf.train.import_meta_graph(
1503        'my-save-dir/my-model-10000.meta')
1504    new_saver.restore(sess, 'my-save-dir/my-model-10000')
1505    # tf.get_collection() returns a list. In this example we only want
1506    # the first one.
1507    train_op = tf.get_collection('train_op')[0]
1508    for step in range(1000000):
1509      sess.run(train_op)
1510  ```
1511
1512  NOTE: Restarting training from saved `meta_graph` only works if the
1513  device assignments have not changed.
1514
1515  Example:
1516  Variables, placeholders, and independent operations can also be stored, as
1517  shown in the following example.
1518
1519  ```Python
1520  # Saving contents and operations.
1521  v1 = tf.placeholder(tf.float32, name="v1")
1522  v2 = tf.placeholder(tf.float32, name="v2")
1523  v3 = tf.math.multiply(v1, v2)
1524  vx = tf.Variable(10.0, name="vx")
1525  v4 = tf.add(v3, vx, name="v4")
1526  saver = tf.train.Saver([vx])
1527  sess = tf.Session()
1528  sess.run(tf.global_variables_initializer())
1529  sess.run(vx.assign(tf.add(vx, vx)))
1530  result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
1531  print(result)
1532  saver.save(sess, "./model_ex1")
1533  ```
1534
1535  Later this model can be restored and contents loaded.
1536
1537  ```Python
1538  # Restoring variables and running operations.
1539  saver = tf.train.import_meta_graph("./model_ex1.meta")
1540  sess = tf.Session()
1541  saver.restore(sess, "./model_ex1")
1542  result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
1543  print(result)
1544  ```
1545
1546  Args:
1547    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
1548      the path) containing a `MetaGraphDef`.
1549    clear_devices: Whether or not to clear the device field for an `Operation`
1550      or `Tensor` during import.
1551    import_scope: Optional `string`. Name scope to add. Only used when
1552      initializing from protocol buffer.
1553    **kwargs: Optional keyword arguments.
1554
1555  Returns:
1556    A saver constructed from `saver_def` in `MetaGraphDef` or None.
1557
1558    A None value is returned if no variables exist in the `MetaGraphDef`
1559    (i.e., there are no variables to restore).
1560
1561  Raises:
1562    RuntimeError: If called with eager execution enabled.
1563
1564  @compatibility(eager)
1565  Exporting/importing meta graphs is not supported. No graph exists when eager
1566  execution is enabled.
1567  @end_compatibility
1568  """  # pylint: disable=g-doc-exception
1569  return _import_meta_graph_with_return_elements(meta_graph_or_file,
1570                                                 clear_devices, import_scope,
1571                                                 **kwargs)[0]
1572
1573
1574def _import_meta_graph_with_return_elements(meta_graph_or_file,
1575                                            clear_devices=False,
1576                                            import_scope=None,
1577                                            return_elements=None,
1578                                            **kwargs):
1579  """Import MetaGraph, and return both a saver and returned elements."""
1580  if context.executing_eagerly():
1581    raise RuntimeError("Exporting/importing meta graphs is not supported when "
1582                       "eager execution is enabled. No graph exists when eager "
1583                       "execution is enabled.")
1584  if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
1585    meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
1586  else:
1587    meta_graph_def = meta_graph_or_file
1588
1589  imported_vars, imported_return_elements = (
1590      meta_graph.import_scoped_meta_graph_with_return_elements(
1591          meta_graph_def,
1592          clear_devices=clear_devices,
1593          import_scope=import_scope,
1594          return_elements=return_elements,
1595          **kwargs))
1596
1597  saver = _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
1598                                                 imported_vars)
1599  return saver, imported_return_elements
1600
1601
1602def _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
1603                                           imported_vars):
1604  """Return a saver for restoring variable values to an imported MetaGraph."""
1605  if meta_graph_def.HasField("saver_def"):
1606    # Infer the scope that is prepended by `import_scoped_meta_graph`.
1607    scope = import_scope
1608    var_names = list(imported_vars.keys())
1609    if var_names:
1610      sample_key = var_names[0]
1611      sample_var = imported_vars[sample_key]
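      # e.g. if sample_key is "v:0" and the imported variable's name is
      # "new_scope/v:0", the slice below infers scope == "new_scope/".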
1612      scope = sample_var.name[:-len(sample_key)]
1613
1614    return Saver(saver_def=meta_graph_def.saver_def, name=scope)
1615  else:
1616    if variables._all_saveable_objects(scope=import_scope):  # pylint: disable=protected-access
1617      # Return the default saver instance for all graph variables.
1618      return Saver()
1619    else:
1620      # If no graph variables exist, then a Saver cannot be constructed.
1621      logging.info("Saver not created because there are no variables in the"
1622                   " graph to restore")
1623      return None
1624
1625
1626@tf_export(v1=["train.export_meta_graph"])
1627def export_meta_graph(filename=None,
1628                      meta_info_def=None,
1629                      graph_def=None,
1630                      saver_def=None,
1631                      collection_list=None,
1632                      as_text=False,
1633                      graph=None,
1634                      export_scope=None,
1635                      clear_devices=False,
1636                      clear_extraneous_savers=False,
1637                      strip_default_attrs=False,
1638                      save_debug_info=False,
1639                      **kwargs):
1640  # pylint: disable=line-too-long
1641  """Returns `MetaGraphDef` proto.
1642
1643  Optionally writes it to filename.
1644
1645  This function exports the graph, saver, and collection objects into
1646  `MetaGraphDef` protocol buffer with the intention of it being imported
1647  at a later time or location to restart training, run inference, or be
1648  a subgraph.
1649
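  A minimal sketch of graph-mode use (the filename is an illustrative
  assumption):

  ```Python
  v = tf.compat.v1.get_variable("v", shape=[1])
  saver = tf.compat.v1.train.Saver([v])
  # Export the default graph together with this saver's SaverDef.
  tf.compat.v1.train.export_meta_graph(filename="/tmp/example.meta",
                                       saver_def=saver.saver_def)
  ```
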
1650  Args:
1651    filename: Optional filename including the path for writing the generated
1652      `MetaGraphDef` protocol buffer.
1653    meta_info_def: `MetaInfoDef` protocol buffer.
1654    graph_def: `GraphDef` protocol buffer.
1655    saver_def: `SaverDef` protocol buffer.
1656    collection_list: List of string keys to collect.
1657    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
1658    graph: The `Graph` to export. If `None`, use the default graph.
1659    export_scope: Optional `string`. Name scope under which to extract the
1660      subgraph. The scope name will be stripped from the node definitions for
1661      easy import later into new name scopes. If `None`, the whole graph is
1662      exported. graph_def and export_scope cannot both be specified.
1663    clear_devices: Whether or not to clear the device field for an `Operation`
1664      or `Tensor` during export.
1665    clear_extraneous_savers: Remove any Saver-related information from the graph
1666      (both Save/Restore ops and SaverDefs) that is not associated with the
1667      provided SaverDef.
1668    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
1669      removed from the NodeDefs. For a detailed guide, see
1670      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
1671    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
1672      which is in the same directory as `filename` and has `_debug` added before
1673      the file extension.
1674    **kwargs: Optional keyword arguments.
1675
1676  Returns:
1677    A `MetaGraphDef` proto.
1678
1679  Raises:
1680    ValueError: When the `GraphDef` is larger than 2GB.
1681    RuntimeError: If called with eager execution enabled.
1682
1683  @compatibility(eager)
1684  Exporting/importing meta graphs is not supported unless both `graph_def` and
1685  `graph` are provided. No graph exists when eager execution is enabled.
1686  @end_compatibility
1687  """
1688  # pylint: enable=line-too-long
1689  if context.executing_eagerly() and not (graph_def is not None and
1690                                          graph is not None):
1691    raise RuntimeError("Exporting/importing meta graphs is not supported when "
1692                       "eager execution is enabled. No graph exists when eager "
1693                       "execution is enabled.")
1694  meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
1695      filename=filename,
1696      meta_info_def=meta_info_def,
1697      graph_def=graph_def,
1698      saver_def=saver_def,
1699      collection_list=collection_list,
1700      as_text=as_text,
1701      graph=graph,
1702      export_scope=export_scope,
1703      clear_devices=clear_devices,
1704      clear_extraneous_savers=clear_extraneous_savers,
1705      strip_default_attrs=strip_default_attrs,
1706      save_debug_info=save_debug_info,
1707      **kwargs)
1708  return meta_graph_def
1709
1710
1711def _wrap_restore_error_with_msg(err, extra_verbiage):
1712  err_msg = ("Restoring from checkpoint failed. This is most likely "
1713             "due to {} from the checkpoint. Please ensure that you "
1714             "have not altered the graph expected based on the checkpoint. "
1715             "Original error:\n\n{}").format(extra_verbiage, err.message)
1716  return err.__class__(err.node_def, err.op, err_msg)
1717
1718
1719ops.register_proto_function(
1720    ops.GraphKeys.SAVERS,
1721    proto_type=saver_pb2.SaverDef,
1722    to_proto=Saver.to_proto,
1723    from_proto=Saver.from_proto)
1724
1725
1726def object_graph_key_mapping(checkpoint_path):
1727  """Return name to key mappings from the checkpoint.
1728
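  A sketch of the returned mapping's shape (names below are illustrative,
  assuming a checkpoint written with `tf.train.Checkpoint`):

  ```Python
  names_to_keys = object_graph_key_mapping("/tmp/tf_ckpt/ckpt-1")
  # e.g. {"v": "v/.ATTRIBUTES/VARIABLE_VALUE"}
  ```
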
1729  Args:
1730    checkpoint_path: string, path to object-based checkpoint
1731
1732  Returns:
1733    Dictionary mapping tensor names to checkpoint keys.
1734  """
1735  reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
1736  object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
1737  object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
1738  object_graph_proto.ParseFromString(object_graph_string)
1739  names_to_keys = {}
1740  for node in object_graph_proto.nodes:
1741    for attribute in node.attributes:
1742      names_to_keys[attribute.full_name] = attribute.checkpoint_key
1743  return names_to_keys
1744
1745
1746def saver_from_object_based_checkpoint(checkpoint_path,
1747                                       var_list=None,
1748                                       builder=None,
1749                                       names_to_keys=None,
1750                                       cached_saver=None):
1751  """Return a `Saver` which reads from an object-based checkpoint.
1752
1753  This function validates that all variables in the variables list are remapped
1754  in the object-based checkpoint (or `names_to_keys` dict if provided). A
1755  saver will be created with the list of remapped variables.
1756
1757  The `cached_saver` argument allows the user to pass in a previously created
1758  saver, so multiple `saver.restore()` calls don't pollute the graph when
1759  building graphs. This assumes that keys are consistent, meaning that the
1760    1) `checkpoint_path` checkpoint, and
1761    2) checkpoint used to create the `cached_saver`
1762  are the same type of object-based checkpoint. If this argument is set, this
1763  function will simply validate that all variables have been remapped by the
1764  checkpoint at `checkpoint_path`.
1765
1766  Note that in general, `tf.train.Checkpoint` should be used to restore/save an
1767  object-based checkpoint.
1768
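  A minimal sketch of typical use (the checkpoint path is an illustrative
  assumption):

  ```Python
  # Assumes a graph with matching variables has already been built.
  # Build a name-based Saver that reads from an object-based checkpoint.
  saver = saver_from_object_based_checkpoint("/tmp/tf_ckpt/ckpt-1")
  with tf.compat.v1.Session() as sess:
    saver.restore(sess, "/tmp/tf_ckpt/ckpt-1")
  ```
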
1769  Args:
1770    checkpoint_path: string, path to object-based checkpoint
1771    var_list: list of `Variables` that appear in the checkpoint. If `None`,
1772      `var_list` will be set to all saveable objects.
1773    builder: a `BaseSaverBuilder` instance. If `None`, a new `BulkSaverBuilder`
1774      will be created.
1775    names_to_keys: dict mapping string tensor names to checkpoint keys. If
1776      `None`, this dict will be generated from the checkpoint file.
1777    cached_saver: Cached `Saver` object with remapped variables.
1778
1779  Returns:
1780    `Saver` with remapped variables for reading from an object-based checkpoint.
1781
1782  Raises:
1783    ValueError: If the checkpoint provided is not an object-based checkpoint.
1784    NotFoundError: If one of the variables in `var_list` cannot be found in the
1785      checkpoint. This could mean the checkpoint or `names_to_keys` mapping is
1786      missing the variable.
1787  """
1788  if names_to_keys is None:
1789    try:
1790      names_to_keys = object_graph_key_mapping(checkpoint_path)
1791    except errors.NotFoundError:
1792      raise ValueError("Checkpoint in %s not an object-based checkpoint." %
1793                       checkpoint_path)
1794  if var_list is None:
1795    var_list = variables._all_saveable_objects()  # pylint: disable=protected-access
1796  if builder is None:
1797    builder = BulkSaverBuilder()
1798
1799  saveables = saveable_object_util.validate_and_slice_inputs(var_list)
1800  current_names = set()
1801  for saveable in saveables:
1802    for spec in saveable.specs:
1803      current_names.add(spec.name)
1804  previous_names = set(names_to_keys.keys())
1805  missing_names = current_names - previous_names
1806  if missing_names:
1807    extra_names = previous_names - current_names
1808    intersecting_names = previous_names.intersection(current_names)
1809    raise errors.NotFoundError(
1810        None,
1811        None,
1812        message=(
1813            "\n\nExisting variables not in the checkpoint: %s\n\n"
1814            "Variable names used when this checkpoint was written which don't "
1815            "exist now: %s\n\n"
1816            "(%d variable name(s) did match)\n\n"
1817            "Could not find some variables in the checkpoint (see names "
1818            "above). Saver was attempting to load an object-based checkpoint "
1819            "(saved using tf.train.Checkpoint or tf.keras.Model.save_weights) "
1820            "using variable names. If the checkpoint was written with eager "
1821            "execution enabled, it's possible that variable names have "
1822            "changed (for example missing a '_1' suffix). It's also "
1823            "possible that there are new variables which did not exist "
1824            "when the checkpoint was written. You can construct a "
1825            "Saver(var_list=...) with only the variables which previously "
1826            "existed, and if variable names have changed you may need to "
1827            "make this a dictionary with the old names as keys. If you're "
1828            "using an Estimator, you'll need to return a tf.train.Saver "
1829            "inside a tf.train.Scaffold from your model_fn.") %
1830        (", ".join(sorted(missing_names)), ", ".join(
1831            sorted(extra_names)), len(intersecting_names)))
1832  for saveable in saveables:
1833    for spec in saveable.specs:
1834      spec.name = names_to_keys[spec.name]
1835  if cached_saver is None:
1836    return Saver(saveables)
1837  return cached_saver
1838