1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16# pylint: disable=invalid-name
17"""Save and restore variables.
18
19Symbols in this file are deprecated. See replacements in
20tensorflow/python/training/trackable and tensorflow/python/training/saving.
21"""
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import collections
27import os.path
28import time
29import uuid
30
31import numpy as np
32from tensorflow.core.protobuf import meta_graph_pb2
33from tensorflow.core.protobuf import saver_pb2
34from tensorflow.core.protobuf import trackable_object_graph_pb2
35from tensorflow.python.client import session
36from tensorflow.python.eager import context
37from tensorflow.python.framework import constant_op
38from tensorflow.python.framework import device as pydev
39from tensorflow.python.framework import errors
40from tensorflow.python.framework import meta_graph
41from tensorflow.python.framework import ops
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import control_flow_ops
44from tensorflow.python.ops import gen_io_ops
45from tensorflow.python.ops import io_ops
46from tensorflow.python.ops import string_ops
47from tensorflow.python.ops import variables
48from tensorflow.python.platform import gfile
49from tensorflow.python.platform import tf_logging as logging
50from tensorflow.python.training import checkpoint_management
51from tensorflow.python.training import py_checkpoint_reader
52from tensorflow.python.training import training_util
53from tensorflow.python.training.saving import saveable_object
54from tensorflow.python.training.saving import saveable_object_util
55from tensorflow.python.training.tracking import base as trackable
56from tensorflow.python.util import compat
57from tensorflow.python.util.tf_export import tf_export
58
59# TODO(allenl): Remove these aliases once all users are migrated off.
60get_checkpoint_state = checkpoint_management.get_checkpoint_state
61update_checkpoint_state = checkpoint_management.update_checkpoint_state
62generate_checkpoint_state_proto = (
63    checkpoint_management.generate_checkpoint_state_proto)
64latest_checkpoint = checkpoint_management.latest_checkpoint
65checkpoint_exists = checkpoint_management.checkpoint_exists
66get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes
67remove_checkpoint = checkpoint_management.remove_checkpoint
68
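# The following is an illustrative sketch, not part of the original module: it
# shows how the checkpoint-state helpers aliased above are typically used.  The
# directory path is hypothetical.
def _example_read_checkpoint_state():
  """Returns the newest checkpoint prefix recorded in a state file (sketch)."""
  state = get_checkpoint_state("/tmp/train_dir")
  if state is None:  # No 'checkpoint' protobuf was found in the directory.
    return None
  # `model_checkpoint_path` is the most recent entry; the state proto also
  # carries `all_model_checkpoint_paths` with every checkpoint still tracked.
  return state.model_checkpoint_path
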
69
70class BaseSaverBuilder(object):
71  """Base class for Savers.
72
73  Can be extended to create different Ops.
74  """
75
76  SaveSpec = saveable_object.SaveSpec
77  SaveableObject = saveable_object.SaveableObject
78
79  # Aliases for code which was moved but still has lots of users.
80  VariableSaveable = saveable_object_util.ReferenceVariableSaveable
81  ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable
82
83  def __init__(self, write_version=saver_pb2.SaverDef.V2):
84    self._write_version = write_version
85
86  def save_op(self, filename_tensor, saveables):
87    """Create an Op to save 'saveables'.
88
89    This is intended to be overridden by subclasses that want to generate
90    different Ops.
91
92    Args:
93      filename_tensor: String Tensor.
94      saveables: A list of BaseSaverBuilder.SaveableObject objects.
95
96    Returns:
97      An Operation that saves the variables.
98
99    Raises:
100      RuntimeError: (implementation detail) if "self._write_version" is an
101        unexpected value.
102    """
103    # pylint: disable=protected-access
104    tensor_names = []
105    tensors = []
106    tensor_slices = []
107    for saveable in saveables:
108      for spec in saveable.specs:
109        tensor_names.append(spec.name)
110        tensors.append(spec.tensor)
111        tensor_slices.append(spec.slice_spec)
112    if self._write_version == saver_pb2.SaverDef.V1:
113      return io_ops._save(
114          filename=filename_tensor,
115          tensor_names=tensor_names,
116          tensors=tensors,
117          tensor_slices=tensor_slices)
118    elif self._write_version == saver_pb2.SaverDef.V2:
119      # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
120      # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
121      return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
122                            tensors)
123    else:
124      raise RuntimeError("Unexpected write_version: %s" % self._write_version)
125
126  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
127                   restore_sequentially):
128    """Restore all tensors contained in saveables.
129
130    By default, this issues separate calls to `restore_op` for each saveable.
131    Subclasses may override to load multiple saveables in a single call.
132
133    Args:
134      filename_tensor: String Tensor.
135      saveables: List of BaseSaverBuilder.SaveableObject objects.
136      preferred_shard: Int.  Shard to open first when loading a sharded file.
137      restore_sequentially: Unused.  Bool.  If true, each restore is sequential.
138
139    Returns:
140      A list of Tensors resulting from reading 'saveables' from
141        'filename'.
142
143    """
144    del restore_sequentially
145    all_tensors = []
146    for saveable in saveables:
147      if saveable.device:
148        device = saveable_object_util.set_cpu0(saveable.device)
149      else:
150        device = None
151      with ops.device(device):
152        all_tensors.extend(
153            self.restore_op(filename_tensor, saveable, preferred_shard))
154    return all_tensors
155
156  # pylint: disable=unused-argument
157  def restore_op(self, filename_tensor, saveable, preferred_shard):
158    """Create ops to restore 'saveable'.
159
160    This is intended to be overridden by subclasses that want to generate
161    different Ops.
162
163    Args:
164      filename_tensor: String Tensor.
165      saveable: A BaseSaverBuilder.SaveableObject object.
166      preferred_shard: Int.  Shard to open first when loading a sharded file.
167
168    Returns:
169      A list of Tensors resulting from reading 'saveable' from
170        'filename'.
171    """
172    # pylint: disable=protected-access
173    tensors = []
174    for spec in saveable.specs:
175      tensors.append(
176          io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec],
177                            [spec.dtype])[0])
178
179    return tensors
180
181  # pylint: enable=unused-argument
182
183  def sharded_filename(self, filename_tensor, shard, num_shards):
184    """Append sharding information to a filename.
185
186    Args:
187      filename_tensor: A string tensor.
188      shard: Integer.  The shard for the filename.
189      num_shards: An int Tensor for the number of shards.
190
191    Returns:
192      A string tensor.
193    """
194    return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)
195
196  def _AddSaveOps(self, filename_tensor, saveables):
197    """Add ops to save variables that are on the same shard.
198
199    Args:
200      filename_tensor: String Tensor.
201      saveables: A list of SaveableObject objects.
202
203    Returns:
204      A tensor with the filename used to save.
205    """
206    save = self.save_op(filename_tensor, saveables)
207    return control_flow_ops.with_dependencies([save], filename_tensor)
208
209  def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device):
210    """Add ops to save the params per shard, for the V2 format.
211
212    Note that the sharded save procedure for the V2 format is different from
213    V1: there is a special "merge" step that merges the small metadata produced
214    from each device.
215
216    Args:
217      checkpoint_prefix: scalar String Tensor.  Interpreted *NOT AS A FILENAME*,
218        but as a prefix of a V2 checkpoint.
219      per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs,
220        as returned by _GroupByDevices().
221
222    Returns:
223      An op to save the variables, which, when evaluated, returns the prefix
224        "<user-fed prefix>" only and does not include the sharded spec suffix.
225    """
226    # IMPLEMENTATION DETAILS: most clients should skip.
227    #
228    # Suffix for any well-formed "checkpoint_prefix", when sharded.
229    # Transformations:
230    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
231    # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>.
232    #
233    # Example:
234    #   During runtime, a temporary directory is first created, which contains
235    #   files
236    #
237    #     <train dir>/myckpt_temp/
238    #        part-?????-of-?????{.index, .data-00000-of-00001}
239    #
240    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
241    #
242    #     <train dir>/
243    #        myckpt{.index, .data-?????-of-?????}
244    #
245    # Users only need to interact with the user-specified prefix, which is
246    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
247    # prefix directly, instead of any physical pathname.  (On failure and
248    # subsequent restore, an outdated and orphaned temporary directory can be
249    # safely removed.)
250    _SHARDED_SUFFIX = "_temp_%s/part" % uuid.uuid4().hex
251    tmp_checkpoint_prefix = string_ops.string_join(
252        [checkpoint_prefix, _SHARDED_SUFFIX])
253
254    num_shards = len(per_device)
255    sharded_saves = []
256    sharded_prefixes = []
257    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
258    last_device = None
259    for shard, (device, saveables) in enumerate(per_device):
260      last_device = device
261      with ops.device(saveable_object_util.set_cpu0(device)):
262        sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
263                                                 num_shards_tensor)
264        sharded_prefixes.append(sharded_filename)
265        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
266
267    with ops.control_dependencies([x.op for x in sharded_saves]):
268      # Co-locates the merge step with the last device.
269      with ops.device(saveable_object_util.set_cpu0(last_device)):
270        # V2 format write path consists of a metadata merge step.  Once merged,
271        # attempts to delete the temporary directory, "<user-fed prefix>_temp".
272        merge_step = gen_io_ops.merge_v2_checkpoints(
273            sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
274        with ops.control_dependencies([merge_step]):
275          # Returns the prefix "<user-fed prefix>" only.  DOES NOT include the
276          # sharded spec suffix.
277          return array_ops.identity(checkpoint_prefix)
278
279  def _AddShardedSaveOps(self, filename_tensor, per_device):
280    """Add ops to save the params per shard.
281
282    Args:
283      filename_tensor: a scalar String Tensor.
284      per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
285        returned by _GroupByDevices().
286
287    Returns:
288      An op to save the variables.
289    """
290    if self._write_version == saver_pb2.SaverDef.V2:
291      return self._AddShardedSaveOpsForV2(filename_tensor, per_device)
292
293    num_shards = len(per_device)
294    sharded_saves = []
295    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
296    for shard, (device, saveables) in enumerate(per_device):
297      with ops.device(device):
298        sharded_filename = self.sharded_filename(filename_tensor, shard,
299                                                 num_shards_tensor)
300        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
301    # Return the sharded name for the save path.
302    with ops.control_dependencies([x.op for x in sharded_saves]):
303      return gen_io_ops.sharded_filespec(filename_tensor, num_shards_tensor)
304
305  def _AddRestoreOps(self,
306                     filename_tensor,
307                     saveables,
308                     restore_sequentially,
309                     reshape,
310                     preferred_shard=-1,
311                     name="restore_all"):
312    """Add operations to restore saveables.
313
314    Args:
315      filename_tensor: Tensor for the path of the file to load.
316      saveables: A list of SaveableObject objects.
317      restore_sequentially: True if we want to restore variables sequentially
318        within a shard.
319      reshape: True if we want to reshape loaded tensors to the shape of the
320        corresponding variable.
321      preferred_shard: Shard to open first when loading a sharded file.
322      name: Name for the returned op.
323
324    Returns:
325      An Operation that restores the variables.
326    """
327    all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
328                                    restore_sequentially)
329
330    assign_ops = []
331    idx = 0
332    # Load and optionally reshape on the CPU, as string tensors are not
333    # available on the GPU.
334    # TODO(touts): Re-enable restore on GPU when we can support annotating
335    # string tensors as "HostMemory" inputs.
336    for saveable in saveables:
337      shapes = None
338      if reshape:
339        # Compute the shapes, let the restore op decide if and how to do
340        # the reshape.
341        shapes = []
342        for spec in saveable.specs:
343          v = spec.tensor
344          shape = v.get_shape()
345          if not shape.is_fully_defined():
346            shape = array_ops.shape(v)
347          shapes.append(shape)
348      saveable_tensors = all_tensors[idx:idx + len(saveable.specs)]
349      idx += len(saveable.specs)
350      assign_ops.append(saveable.restore(saveable_tensors, shapes))
351
352    # Create a Noop that has control dependencies from all the updates.
353    return control_flow_ops.group(*assign_ops, name=name)
354
355  def _AddShardedRestoreOps(self, filename_tensor, per_device,
356                            restore_sequentially, reshape):
357    """Add Ops to restore variables from multiple devices.
358
359    Args:
360      filename_tensor: Tensor for the path of the file to load.
361      per_device: A list of (device, SaveableObject) pairs, as returned by
362        _GroupByDevices().
363      restore_sequentially: True if we want to restore variables sequentially
364        within a shard.
365      reshape: True if we want to reshape loaded tensors to the shape of the
366        corresponding variable.
367
368    Returns:
369      An Operation that restores the variables.
370    """
371    sharded_restores = []
372    for shard, (device, saveables) in enumerate(per_device):
373      with ops.device(device):
374        sharded_restores.append(
375            self._AddRestoreOps(
376                filename_tensor,
377                saveables,
378                restore_sequentially,
379                reshape,
380                preferred_shard=shard,
381                name="restore_shard"))
382    return control_flow_ops.group(*sharded_restores, name="restore_all")
383
384  def _GroupByDevices(self, saveables):
385    """Group Variable tensor slices per device.
386
387    TODO(touts): Make sure that all the devices found are on different
388    job/replica/task/cpu|gpu.  It would be bad if 2 were on the same device.
389    It can happen if the devices are unspecified.
390
391    Args:
392      saveables: A list of BaseSaverBuilder.SaveableObject objects.
393
394    Returns:
395      A list of (device_name, BaseSaverBuilder.SaveableObject) tuples.
396      The list is sorted by ascending device_name.
397
398    Raises:
399      ValueError: If the tensors of a saveable are on different devices.
400    """
401    per_device = collections.defaultdict(lambda: [])
402    for saveable in saveables:
403      canonical_device = set(
404          pydev.canonical_name(spec.device) for spec in saveable.specs)
405      if len(canonical_device) != 1:
406        raise ValueError("All tensors of a saveable object must be "
407                         "on the same device: %s" % saveable.name)
408      per_device[canonical_device.pop()].append(saveable)
409    return sorted(per_device.items(), key=lambda t: t[0])
410
411  def build(self,
412            names_to_saveables,
413            reshape=False,
414            sharded=False,
415            max_to_keep=5,
416            keep_checkpoint_every_n_hours=10000.0,
417            name=None,
418            restore_sequentially=False,
419            filename="model"):
420    """Builds save/restore graph nodes or runs save/restore in eager mode.
421
422    Args:
423      names_to_saveables: A dictionary mapping name to a Variable or
424        SaveableObject. Each name will be associated with the corresponding
425        variable in the checkpoint.
426      reshape: If True, allow restoring parameters from a checkpoint where
427        the parameters have a different shape.  This is only needed when you try
428        to restore from a Dist-Belief checkpoint, and only sometimes.
429      sharded: If True, shard the checkpoints, one per device that has Variable
430        nodes.
431      max_to_keep: Maximum number of checkpoints to keep.  As new checkpoints
432        are created, old ones are deleted.  If None or 0, no checkpoints are
433        deleted from the filesystem but only the last one is kept in the
434        `checkpoint` file.  Presently the number is only roughly enforced.  For
435        example in case of restarts more than max_to_keep checkpoints may be
436        kept.
437      keep_checkpoint_every_n_hours: How often checkpoints should be kept.
438        Defaults to 10,000 hours.
439      name: String.  Optional name to use as a prefix when adding operations.
440      restore_sequentially: A Bool, which if true, causes restore of different
441        variables to happen sequentially within each device.
442      filename: If known at graph construction time, filename used for variable
443        loading/saving. If None, then the default name "model" will be used.
444
445    Returns:
446      A SaverDef proto.
447
448    Raises:
449      TypeError: If 'names_to_saveables' is not a dictionary mapping string
450        keys to variable Tensors.
451      ValueError: If any of the keys or values in 'names_to_saveables' is not
452        unique.
453    """
454    return self._build_internal(
455        names_to_saveables=names_to_saveables,
456        reshape=reshape,
457        sharded=sharded,
458        max_to_keep=max_to_keep,
459        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
460        name=name,
461        restore_sequentially=restore_sequentially,
462        filename=filename)
463
464  def _build_internal(self,
465                      names_to_saveables,
466                      reshape=False,
467                      sharded=False,
468                      max_to_keep=5,
469                      keep_checkpoint_every_n_hours=10000.0,
470                      name=None,
471                      restore_sequentially=False,
472                      filename="model",
473                      build_save=True,
474                      build_restore=True):
475    """build() with option to only perform save and restore."""
476    if not context.executing_eagerly() and (not build_save or
477                                            not build_restore):
478      raise ValueError("save and restore operations need to be built together"
479                       " when eager execution is not enabled.")
480
481    saveables = saveable_object_util.validate_and_slice_inputs(
482        names_to_saveables)
483    if max_to_keep is None:
484      max_to_keep = 0
485
486    with ops.name_scope(name, "save",
487                        [saveable.op for saveable in saveables]) as name:
488      # Add a placeholder string tensor for the filename.
489      filename_tensor = array_ops.placeholder_with_default(
490          filename or "model", shape=(), name="filename")
491      # Keep the name "Const" for backwards compatibility.
492      filename_tensor = array_ops.placeholder_with_default(
493          filename_tensor, shape=(), name="Const")
494
495      # Add the save ops.
496      if sharded:
497        per_device = self._GroupByDevices(saveables)
498        if build_save:
499          save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
500        if build_restore:
501          restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
502                                                  restore_sequentially, reshape)
503      else:
504        if build_save:
505          save_tensor = self._AddSaveOps(filename_tensor, saveables)
506        if build_restore:
507          restore_op = self._AddRestoreOps(filename_tensor, saveables,
508                                           restore_sequentially, reshape)
509
510    # In the following use case, it's possible to have restore_ops be called
511    # something else:
512    # - Build inference graph and export a meta_graph.
513    # - Import the inference meta_graph
514    # - Extend the inference graph to a train graph.
515    # - Export a new meta_graph.
516    # Now the second restore_op will be called "restore_all_1".
517    # As such, comment out the assert for now until we know whether supporting
518    # such usage model makes sense.
519    #
520    # assert restore_op.name.endswith("restore_all"), restore_op.name
521    if context.executing_eagerly():
522      # Store the tensor values to the tensor_names.
523      save_tensor_name = save_tensor.numpy() if build_save else ""
524      return saver_pb2.SaverDef(
525          filename_tensor_name=filename_tensor.numpy(),
526          save_tensor_name=save_tensor_name,
527          restore_op_name="",
528          max_to_keep=max_to_keep,
529          sharded=sharded,
530          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
531          version=self._write_version)
532    else:
533      graph = ops.get_default_graph()
534      # Do some sanity checking on collections containing
535      # PartitionedVariables. If a saved collection has a PartitionedVariable,
536      # the GraphDef needs to include concat ops to get the value (or there'll
537      # be a lookup error on load).
538      check_collection_list = graph.get_all_collection_keys()
539      for collection_type in check_collection_list:
540        for element in graph.get_collection(collection_type):
541          if isinstance(element, variables.PartitionedVariable):
542            try:
543              graph.get_operation_by_name(element.name)
544            except KeyError:
545              # Create a concat op for this PartitionedVariable. The user may
546              # not need it, but we'll try looking it up on MetaGraph restore
547              # since it's in a collection.
548              element.as_tensor()
549      return saver_pb2.SaverDef(
550          filename_tensor_name=filename_tensor.name,
551          save_tensor_name=save_tensor.name,
552          restore_op_name=restore_op.name,
553          max_to_keep=max_to_keep,
554          sharded=sharded,
555          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
556          version=self._write_version)
557
558
559class BulkSaverBuilder(BaseSaverBuilder):
560  """SaverBuilder with support for bulk restoring multiple saveables."""
561
562  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
563                   restore_sequentially):
564
565    # Ignored: bulk restore is internally sequential.
566    del restore_sequentially
567    restore_specs = []
568    for saveable in saveables:
569      for spec in saveable.specs:
570        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
571
572    names, slices, dtypes = zip(*restore_specs)
573    # Load all tensors onto CPU 0 for compatibility with existing code.
574    with ops.device("cpu:0"):
575      return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
576
577
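# The following is an illustrative sketch, not part of the original module: it
# shows how a builder can be used directly to produce a SaverDef for a small
# graph.  `variables.VariableV1` is assumed to be available; adjust the
# variable constructor for the TensorFlow version at hand.
def _example_build_saver_def():
  """Builds a SaverDef for two variables with BulkSaverBuilder (sketch)."""
  with ops.Graph().as_default():
    v1 = variables.VariableV1([1.0, 2.0], name="v1")
    v2 = variables.VariableV1([3.0], name="v2")
    builder = BulkSaverBuilder(write_version=saver_pb2.SaverDef.V2)
    # The dict keys become the names under which tensors are checkpointed.
    saver_def = builder.build({"v1": v1, "v2": v2}, sharded=False,
                              filename="model")
    # The proto records the filename/save/restore node names for this graph.
    return saver_def
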
578def _get_saver_or_default():
579  """Returns the saver from SAVERS collection, or creates a default one.
580
581  This method is used by other members of the training module, such as
582  `Scaffold`, or `CheckpointSaverHook`.
583
584  Returns:
585    `Saver`.
586
587  Raises:
588    RuntimeError: If the SAVERS collection already has more than one item.
589  """
590  collection_key = ops.GraphKeys.SAVERS
591  savers = ops.get_collection(collection_key)
592  if savers:
593    if len(savers) > 1:
594      raise RuntimeError(
595          "More than one item in collection {}. "
596          "Please indicate which one to use by passing it to the constructor."
597          .format(collection_key))
598    return savers[0]
599  saver = Saver(sharded=True, allow_empty=True)
600  if saver is not None:
601    ops.add_to_collection(collection_key, saver)
602  return saver
603
604
605@tf_export(v1=["train.Saver"])
606class Saver(object):
607  """Saves and restores variables.
608
609  See [Variables](https://tensorflow.org/guide/variables)
610  for an overview of variables, saving and restoring.
611
612  The `Saver` class adds ops to save and restore variables to and from
613  *checkpoints*.  It also provides convenience methods to run these ops.
614
615  Checkpoints are binary files in a proprietary format which map variable names
616  to tensor values.  The best way to examine the contents of a checkpoint is to
617  load it using a `Saver`.
618
619  Savers can automatically number checkpoint filenames with a provided counter.
620  This lets you keep multiple checkpoints at different steps while training a
621  model.  For example you can number the checkpoint filenames with the training
622  step number.  To avoid filling up disks, savers manage checkpoint files
623  automatically. For example, they can keep only the N most recent files, or
624  one checkpoint for every N hours of training.
625
626  You number checkpoint filenames by passing a value to the optional
627  `global_step` argument to `save()`:
628
629  ```python
630  saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
631  ...
632  saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
633  ```
634
635  Additionally, optional arguments to the `Saver()` constructor let you control
636  the proliferation of checkpoint files on disk:
637
638  * `max_to_keep` indicates the maximum number of recent checkpoint files to
639    keep.  As new files are created, older files are deleted.   If None or 0,
640    no checkpoints are deleted from the filesystem but only the last one is
641    kept in the `checkpoint` file.  Defaults to 5 (that is, the 5 most recent
642    checkpoint files are kept.)
643
644  * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent
645    `max_to_keep` checkpoint files, you might want to keep one checkpoint file
646    for every N hours of training.  This can be useful if you want to later
647    analyze how a model progressed during a long training session.  For
648    example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep
649    one checkpoint file for every 2 hours of training.  The default value of
650    10,000 hours effectively disables the feature.
651
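  For example, a saver that keeps the 5 most recent checkpoints plus one
  checkpoint for every 2 hours of training could be constructed as follows
  (a minimal sketch):

  ```python
  saver = tf.compat.v1.train.Saver(
      max_to_keep=5, keep_checkpoint_every_n_hours=2)
  ```
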
652  Note that you still have to call the `save()` method to save the model.
653  Passing these arguments to the constructor will not save variables
654  automatically for you.
655
656  A training program that saves regularly looks like:
657
658  ```python
659  ...
660  # Create a saver.
661  saver = tf.compat.v1.train.Saver(...variables...)
662  # Launch the graph and train, saving the model every 1,000 steps.
663  sess = tf.compat.v1.Session()
664  for step in range(1000000):
665      sess.run(..training_op..)
666      if step % 1000 == 0:
667          # Append the step number to the checkpoint name:
668          saver.save(sess, 'my-model', global_step=step)
669  ```
670
671  In addition to checkpoint files, savers keep a protocol buffer on disk with
672  the list of recent checkpoints. This is used to manage numbered checkpoint
673  files and by `latest_checkpoint()`, which makes it easy to discover the path
674  to the most recent checkpoint. That protocol buffer is stored in a file named
675  'checkpoint' next to the checkpoint files.
676
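  For example, to resume from the most recent checkpoint (the directory path
  below is hypothetical):

  ```python
  ckpt_path = tf.compat.v1.train.latest_checkpoint('/tmp/train_dir')
  if ckpt_path:
    saver.restore(sess, ckpt_path)
  ```
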
677  If you create several savers, you can specify a different filename for the
678  protocol buffer file in the call to `save()`.
679  """
680
681  def __init__(self,
682               var_list=None,
683               reshape=False,
684               sharded=False,
685               max_to_keep=5,
686               keep_checkpoint_every_n_hours=10000.0,
687               name=None,
688               restore_sequentially=False,
689               saver_def=None,
690               builder=None,
691               defer_build=False,
692               allow_empty=False,
693               write_version=saver_pb2.SaverDef.V2,
694               pad_step_number=False,
695               save_relative_paths=False,
696               filename=None):
697    """Creates a `Saver`.
698
699    The constructor adds ops to save and restore variables.
700
701    `var_list` specifies the variables that will be saved and restored. It can
702    be passed as a `dict` or a list:
703
704    * A `dict` of names to variables: The keys are the names that will be
705      used to save or restore the variables in the checkpoint files.
706    * A list of variables: The variables will be keyed with their op name in
707      the checkpoint files.
708
709    For example:
710
711    ```python
712    v1 = tf.Variable(..., name='v1')
713    v2 = tf.Variable(..., name='v2')
714
715    # Pass the variables as a dict:
716    saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2})
717
718    # Or pass them as a list.
719    saver = tf.compat.v1.train.Saver([v1, v2])
720    # Passing a list is equivalent to passing a dict with the variable op names
721    # as keys:
722    saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]})
723    ```
724
725    Note: the newer `AutoTrackable` API is not supported by `Saver`. Use the
726    `tf.train.Checkpoint` class instead.
727
728    The optional `reshape` argument, if `True`, allows restoring a variable from
729    a save file where the variable had a different shape, but the same number
730    of elements and type.  This is useful if you have reshaped a variable and
731    want to reload it from an older checkpoint.
732
733    The optional `sharded` argument, if `True`, instructs the saver to shard
734    checkpoints per device.
735
736    Args:
737      var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
738        names to `SaveableObject`s. If `None`, defaults to the list of all
739        saveable objects.
740      reshape: If `True`, allows restoring parameters from a checkpoint where
741        the variables have a different shape.
742      sharded: If `True`, shard the checkpoints, one per device.
743      max_to_keep: Maximum number of recent checkpoints to keep. Defaults to 5.
744      keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to
745        10,000 hours.
746      name: String.  Optional name to use as a prefix when adding operations.
747      restore_sequentially: A `Bool`, which if true, causes restore of different
748        variables to happen sequentially within each device.  This can lower
749        memory usage when restoring very large models.
750      saver_def: Optional `SaverDef` proto to use instead of running the
751        builder. This is only useful for specialty code that wants to recreate a
752        `Saver` object for a previously built `Graph` that had a `Saver`. The
753        `saver_def` proto should be the one returned by the `as_saver_def()`
754        call of the `Saver` that was created for that `Graph`.
755      builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
756        Defaults to `BulkSaverBuilder()`.
757      defer_build: If `True`, defer adding the save and restore ops to the
758        `build()` call. In that case `build()` should be called before
759        finalizing the graph or using the saver.
760      allow_empty: If `False` (default) raise an error if there are no variables
761        in the graph. Otherwise, construct the saver anyway and make it a no-op.
762      write_version: controls what format to use when saving checkpoints.  It
763        also affects certain filepath matching logic.  The V2 format is the
764        recommended choice: it is much more optimized than V1 in terms of memory
765        required and latency incurred during restore.  Regardless of this
766        flag, the Saver is able to restore from both V2 and V1 checkpoints.
767      pad_step_number: if True, pads the global step number in the checkpoint
768        filepaths to some fixed width (8 by default).  This is turned off by
769        default.
770      save_relative_paths: If `True`, will write relative paths to the
771        checkpoint state file. This is needed if the user wants to copy the
772        checkpoint directory and reload from the copied directory.
773      filename: If known at graph construction time, filename used for variable
774        loading/saving.
775
776    Raises:
777      TypeError: If `var_list` is invalid.
778      ValueError: If any of the keys or values in `var_list` are not unique.
779      RuntimeError: If eager execution is enabled and `var_list` does not
780        specify a list of variables to save.
781
782    @compatibility(eager)
783    When eager execution is enabled, `var_list` must specify a `list` or `dict`
784    of variables to save. Otherwise, a `RuntimeError` will be raised.
785
786    Although Saver works in some cases when executing eagerly, it is
787    fragile. Please switch to `tf.train.Checkpoint` or
788    `tf.keras.Model.save_weights`, which perform a more robust object-based
789    saving. These APIs will load checkpoints written by `Saver`.
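
    A minimal object-based equivalent, where `model` and `optimizer` stand in
    for the trackable objects being checkpointed:

    ```python
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    save_path = ckpt.save('/tmp/ckpt/model')  # writes a checkpoint
    ckpt.restore(save_path)                   # reads it back
    ```
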
790    @end_compatibility
791    """
792    if defer_build and var_list:
793      raise ValueError(
794          "If `var_list` is provided then build cannot be deferred. "
795          "Either set defer_build=False or var_list=None.")
796    if context.executing_eagerly():
797      logging.warning(
798          "Saver is deprecated, please switch to tf.train.Checkpoint or "
799          "tf.keras.Model.save_weights for training checkpoints. When "
800          "executing eagerly variables do not necessarily have unique names, "
801          "and so the variable.name-based lookups Saver performs are "
802          "error-prone.")
803      if var_list is None:
804        raise RuntimeError(
805            "When eager execution is enabled, `var_list` must specify a list "
806            "or dict of variables to save")
807    self._var_list = var_list
808    self._reshape = reshape
809    self._sharded = sharded
810    self._max_to_keep = max_to_keep
811    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
812    self._name = name
813    self._restore_sequentially = restore_sequentially
814    self.saver_def = saver_def
815    self._builder = builder
816    self._is_built = False
817    self._allow_empty = allow_empty
818    self._is_empty = None
819    self._write_version = write_version
820    self._pad_step_number = pad_step_number
821    self._filename = filename
822    self._last_checkpoints = []
823    self._checkpoints_to_be_deleted = []
824    if context.executing_eagerly():
825      self._next_checkpoint_time = (
826          time.time() + self._keep_checkpoint_every_n_hours * 3600)
827    elif not defer_build:
828      self.build()
829    if self.saver_def:
830      self._check_saver_def()
831      self._write_version = self.saver_def.version
832    self._save_relative_paths = save_relative_paths
833    # For compatibility with object-based checkpoints, we may build a second
834    # Saver to read the renamed keys.
835    self._object_restore_saver = None
836
837  def build(self):
838    if context.executing_eagerly():
839      raise RuntimeError("Use save/restore instead of build in eager mode.")
840    self._build(self._filename, build_save=True, build_restore=True)
841
842  def _build_eager(self, checkpoint_path, build_save, build_restore):
843    self._build(
844        checkpoint_path, build_save=build_save, build_restore=build_restore)
845
846  def _build(self, checkpoint_path, build_save, build_restore):
847    """Builds saver_def."""
848    if not context.executing_eagerly():
849      if self._is_built:
850        return
851      self._is_built = True
852
853    if not self.saver_def or context.executing_eagerly():
854      if self._builder is None:
855        self._builder = BulkSaverBuilder(self._write_version)
856
857      if self._var_list is None:
858        # pylint: disable=protected-access
859        self._var_list = variables._all_saveable_objects()
860      if not self._var_list:
861        if self._allow_empty:
862          self._is_empty = True
863          return
864        else:
865          raise ValueError("No variables to save")
866      self._is_empty = False
867
868      self.saver_def = self._builder._build_internal(  # pylint: disable=protected-access
869          self._var_list,
870          reshape=self._reshape,
871          sharded=self._sharded,
872          max_to_keep=self._max_to_keep,
873          keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
874          name=self._name,
875          restore_sequentially=self._restore_sequentially,
876          filename=checkpoint_path,
877          build_save=build_save,
878          build_restore=build_restore)
879    elif self.saver_def and self._name:
880      # Since self._name is used as a name_scope by builder(), we are
881      # overloading the use of this field to represent the "import_scope" as
882      # well.
883      self.saver_def.filename_tensor_name = ops.prepend_name_scope(
884          self.saver_def.filename_tensor_name, self._name)
885      self.saver_def.save_tensor_name = ops.prepend_name_scope(
886          self.saver_def.save_tensor_name, self._name)
887      self.saver_def.restore_op_name = ops.prepend_name_scope(
888          self.saver_def.restore_op_name, self._name)
889
890    self._check_saver_def()
891    if not context.executing_eagerly():
892      # Updates next checkpoint time.
893      # Set in __init__ when executing eagerly.
894      self._next_checkpoint_time = (
895          time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)
896
897  def _check_saver_def(self):
898    if not isinstance(self.saver_def, saver_pb2.SaverDef):
899      raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
900                       self.saver_def)
901    if not context.executing_eagerly():
902      if not self.saver_def.save_tensor_name:
903        raise ValueError("saver_def must specify the save_tensor_name: %s" %
904                         str(self.saver_def))
905      if not self.saver_def.restore_op_name:
906        raise ValueError("saver_def must specify the restore_op_name: %s" %
907                         str(self.saver_def))
908
909  def _CheckpointFilename(self, p):
910    """Returns the checkpoint filename given a `(filename, time)` pair.
911
912    Args:
913      p: (filename, time) pair.
914
915    Returns:
916      Checkpoint file name.
917    """
918    name, _ = p
919    return name
920
921  def _RecordLastCheckpoint(self, latest_save_path):
922    """Manages the list of the latest checkpoints."""
923    if not self.saver_def.max_to_keep:
924      return
925    # Remove first from list if the same name was used before.
926    for p in self._last_checkpoints:
927      if latest_save_path == self._CheckpointFilename(p):
928        self._last_checkpoints.remove(p)
929    # Append new path to list
930    self._last_checkpoints.append((latest_save_path, time.time()))
931
932    # If more than max_to_keep, remove oldest.
933    if len(self._last_checkpoints) > self.saver_def.max_to_keep:
934      self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0))
935
936  def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"):
937    """Deletes old checkpoints if necessary.
938
939    `self._checkpoints_to_be_deleted` is going to contain checkpoints that are
940    over `max_to_keep`.  They are going to be deleted.  If
941    `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
942    every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
943    kept for every 0.5 hours of training; if `N` is 10, an additional
944    checkpoint is kept for every 10 hours of training.
945
946    Args:
947      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
948    """
949    if self._checkpoints_to_be_deleted:
950      p = self._checkpoints_to_be_deleted.pop(0)
951      # Do not delete the file if keep_checkpoint_every_n_hours is set and we
952      # have reached N hours of training.
953      should_keep = p[1] > self._next_checkpoint_time
954      if should_keep:
955        self._next_checkpoint_time += (
956            self.saver_def.keep_checkpoint_every_n_hours * 3600)
957        return
958
959      # Otherwise delete the files.
960      try:
961        checkpoint_management.remove_checkpoint(
962            self._CheckpointFilename(p), self.saver_def.version,
963            meta_graph_suffix)
964      except Exception as e:  # pylint: disable=broad-except
965        logging.warning("Ignoring: %s", str(e))
966
967  def as_saver_def(self):
968    """Generates a `SaverDef` representation of this saver.
969
970    Returns:
971      A `SaverDef` proto.
972    """
973    return self.saver_def
974
975  def to_proto(self, export_scope=None):
976    """Converts this `Saver` to a `SaverDef` protocol buffer.
977
978    Args:
979      export_scope: Optional `string`. Name scope to remove.
980
981    Returns:
982      A `SaverDef` protocol buffer.
983    """
984    if export_scope is None:
985      return self.saver_def
986
987    if not (self.saver_def.filename_tensor_name.startswith(export_scope) and
988            self.saver_def.save_tensor_name.startswith(export_scope) and
989            self.saver_def.restore_op_name.startswith(export_scope)):
990      return None
991
992    saver_def = saver_pb2.SaverDef()
993    saver_def.CopyFrom(self.saver_def)
994    saver_def.filename_tensor_name = ops.strip_name_scope(
995        saver_def.filename_tensor_name, export_scope)
996    saver_def.save_tensor_name = ops.strip_name_scope(
997        saver_def.save_tensor_name, export_scope)
998    saver_def.restore_op_name = ops.strip_name_scope(saver_def.restore_op_name,
999                                                     export_scope)
1000    return saver_def
1001
1002  @staticmethod
1003  def from_proto(saver_def, import_scope=None):
1004    """Returns a `Saver` object created from `saver_def`.
1005
1006    Args:
1007      saver_def: a `SaverDef` protocol buffer.
1008      import_scope: Optional `string`. Name scope to use.
1009
1010    Returns:
1011      A `Saver` built from saver_def.
1012    """
1013    return Saver(saver_def=saver_def, name=import_scope)
1014
1015  @property
1016  def last_checkpoints(self):
1017    """List of not-yet-deleted checkpoint filenames.
1018
1019    You can pass any of the returned values to `restore()`.
1020
1021    Returns:
1022      A list of checkpoint filenames, sorted from oldest to newest.
1023    """
1024    return list(self._CheckpointFilename(p) for p in self._last_checkpoints)
1025
1026  def set_last_checkpoints(self, last_checkpoints):
1027    """DEPRECATED: Use set_last_checkpoints_with_time.
1028
1029    Sets the list of old checkpoint filenames.
1030
1031    Args:
1032      last_checkpoints: A list of checkpoint filenames.
1033
1034    Raises:
1035      AssertionError: If last_checkpoints is not a list.
1036    """
1037    assert isinstance(last_checkpoints, list)
1038    # We use a timestamp of +inf so that this checkpoint will never be
1039    # deleted.  This is both safe and backwards compatible to a previous
1040    # version of the code which used s[1] as the "timestamp".
1041    self._last_checkpoints = [(s, np.inf) for s in last_checkpoints]
1042
1043  def set_last_checkpoints_with_time(self, last_checkpoints_with_time):
1044    """Sets the list of old checkpoint filenames and timestamps.
1045
1046    Args:
1047      last_checkpoints_with_time: A list of tuples of checkpoint filenames and
1048        timestamps.
1049
1050    Raises:
1051      AssertionError: If last_checkpoints_with_time is not a list.
1052    """
1053    assert isinstance(last_checkpoints_with_time, list)
1054    self._last_checkpoints = last_checkpoints_with_time
1055
1056  def recover_last_checkpoints(self, checkpoint_paths):
1057    """Recovers the internal saver state after a crash.
1058
1059    This method is useful for recovering the "self._last_checkpoints" state.
1060
1061    Globs for the checkpoints pointed to by `checkpoint_paths`.  If the files
1062    exist, use their mtime as the checkpoint timestamp.
1063
1064    Args:
1065      checkpoint_paths: a list of checkpoint paths.
1066    """
1067    checkpoints_with_mtimes = []
1068    for checkpoint_path in checkpoint_paths:
1069      mtime = checkpoint_management.get_checkpoint_mtimes([checkpoint_path])
1070      if mtime:
1071        checkpoints_with_mtimes.append((checkpoint_path, mtime[0]))
1072    self.set_last_checkpoints_with_time(checkpoints_with_mtimes)
1073
1074  def save(self,
1075           sess,
1076           save_path,
1077           global_step=None,
1078           latest_filename=None,
1079           meta_graph_suffix="meta",
1080           write_meta_graph=True,
1081           write_state=True,
1082           strip_default_attrs=False,
1083           save_debug_info=False):
1084    # pylint: disable=line-too-long
1085    """Saves variables.
1086
1087    This method runs the ops added by the constructor for saving variables.
1088    It requires a session in which the graph was launched.  The variables to
1089    save must also have been initialized.
1090
1091    The method returns the path prefix of the newly created checkpoint files.
1092    This string can be passed directly to a call to `restore()`.
1093
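    For example (the prefix and step value are illustrative):

    ```python
    prefix = saver.save(sess, 'my-model', global_step=100)  # 'my-model-100'
    saver.restore(sess, prefix)
    ```
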
1094    Args:
1095      sess: A Session to use to save the variables.
1096      save_path: String.  Prefix of filenames created for the checkpoint.
1097      global_step: If provided the global step number is appended to `save_path`
1098        to create the checkpoint filenames. The optional argument can be a
1099        `Tensor`, a `Tensor` name or an integer.
1100      latest_filename: Optional name for the protocol buffer file that will
1101        contain the list of most recent checkpoints.  That file, kept in the
1102        same directory as the checkpoint files, is automatically managed by the
1103        saver to keep track of recent checkpoints.  Defaults to 'checkpoint'.
1104      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
1105      write_meta_graph: `Boolean` indicating whether or not to write the meta
1106        graph file.
1107      write_state: `Boolean` indicating whether or not to write the
1108        `CheckpointStateProto`.
1109      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
1110        removed from the NodeDefs. For a detailed guide, see
1111        [Stripping Default-Valued
1112          Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
1113      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
1114        which is in the same directory as save_path and has `_debug` added
1115        before the file extension. This is only enabled when `write_meta_graph`
1116        is `True`.
1117
1118    Returns:
1119      A string: path prefix used for the checkpoint files.  If the saver is
1120        sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
1121        is the number of shards created.
1122      If the saver is empty, returns None.
1123
1124    Raises:
1125      TypeError: If `sess` is not a `Session`.
1126      ValueError: If `latest_filename` contains path components, or if it
1127        collides with `save_path`.
1128      RuntimeError: If save and restore ops weren't built.
1129    """
1130    # pylint: enable=line-too-long
1131    if not self._is_built and not context.executing_eagerly():
1132      raise RuntimeError(
1133          "`build()` should be called before save if defer_build==True")
1134    if latest_filename is None:
1135      latest_filename = "checkpoint"
1136    if self._write_version != saver_pb2.SaverDef.V2:
1137      logging.warning("*******************************************************")
1138      logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
1139      logging.warning("Consider switching to the more efficient V2 format:")
1140      logging.warning("   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
1141      logging.warning("now on by default.")
1142      logging.warning("*******************************************************")
1143
1144    if os.path.split(latest_filename)[0]:
1145      raise ValueError("'latest_filename' must not contain path components")
1146
1147    if global_step is not None:
1148      if not isinstance(global_step, compat.integral_types):
1149        global_step = training_util.global_step(sess, global_step)
1150      checkpoint_file = "%s-%d" % (save_path, global_step)
1151      if self._pad_step_number:
1152        # Zero-pads the step numbers, so that they are sorted when listed.
1153        checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
1154    else:
1155      checkpoint_file = save_path
1156      if os.path.basename(save_path) == latest_filename and not self._sharded:
1157        # Guard against collision between data file and checkpoint state file.
1158        raise ValueError(
1159            "'latest_filename' collides with 'save_path': '%s' and '%s'" %
1160            (latest_filename, save_path))
1161
1162    if (not context.executing_eagerly() and
1163        not isinstance(sess, session.SessionInterface)):
1164      raise TypeError("'sess' must be a Session; %s" % sess)
1165
1166    save_path_parent = os.path.dirname(save_path)
1167    if not self._is_empty:
1168      try:
1169        if context.executing_eagerly():
1170          self._build_eager(
1171              checkpoint_file, build_save=True, build_restore=False)
1172          model_checkpoint_path = self.saver_def.save_tensor_name
1173        else:
1174          model_checkpoint_path = sess.run(
1175              self.saver_def.save_tensor_name,
1176              {self.saver_def.filename_tensor_name: checkpoint_file})
1177
1178        model_checkpoint_path = compat.as_str(model_checkpoint_path)
1179        if write_state:
1180          self._RecordLastCheckpoint(model_checkpoint_path)
1181          checkpoint_management.update_checkpoint_state_internal(
1182              save_dir=save_path_parent,
1183              model_checkpoint_path=model_checkpoint_path,
1184              all_model_checkpoint_paths=self.last_checkpoints,
1185              latest_filename=latest_filename,
1186              save_relative_paths=self._save_relative_paths)
1187          self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
1188      except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
1189        if not gfile.IsDirectory(save_path_parent):
1190          exc = ValueError(
1191              "Parent directory of {} doesn't exist, can't save.".format(
1192                  save_path))
1193        raise exc
1194
1195    if write_meta_graph:
1196      meta_graph_filename = checkpoint_management.meta_graph_filename(
1197          checkpoint_file, meta_graph_suffix=meta_graph_suffix)
1198      if not context.executing_eagerly():
1199        with sess.graph.as_default():
1200          self.export_meta_graph(
1201              meta_graph_filename,
1202              strip_default_attrs=strip_default_attrs,
1203              save_debug_info=save_debug_info)
1204
1205    if self._is_empty:
1206      return None
1207    else:
1208      return model_checkpoint_path
1209
1210  def export_meta_graph(self,
1211                        filename=None,
1212                        collection_list=None,
1213                        as_text=False,
1214                        export_scope=None,
1215                        clear_devices=False,
1216                        clear_extraneous_savers=False,
1217                        strip_default_attrs=False,
1218                        save_debug_info=False):
1219    # pylint: disable=line-too-long
1220    """Writes `MetaGraphDef` to save_path/filename.
1221
1222    Args:
1223      filename: Optional meta_graph filename including the path.
1224      collection_list: List of string keys to collect.
1225      as_text: If `True`, writes the meta_graph as an ASCII proto.
1226      export_scope: Optional `string`. Name scope to remove.
1227      clear_devices: Whether or not to clear the device field for an `Operation`
1228        or `Tensor` during export.
1229      clear_extraneous_savers: Remove any Saver-related information from the
1230        graph (both Save/Restore ops and SaverDefs) not associated with
1231        this Saver.
1232      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
1233        removed from the NodeDefs. For a detailed guide, see
1234        [Stripping Default-Valued
1235          Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
1236      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
1237        which is in the same directory as filename and has `_debug` added
1238        before the file extension.
1239
1240    Returns:
1241      A `MetaGraphDef` proto.
1242    """
1243    # pylint: enable=line-too-long
1244    return export_meta_graph(
1245        filename=filename,
1246        graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
1247        saver_def=self.saver_def,
1248        collection_list=collection_list,
1249        as_text=as_text,
1250        export_scope=export_scope,
1251        clear_devices=clear_devices,
1252        clear_extraneous_savers=clear_extraneous_savers,
1253        strip_default_attrs=strip_default_attrs,
1254        save_debug_info=save_debug_info)
1255
1256  def restore(self, sess, save_path):
1257    """Restores previously saved variables.
1258
1259    This method runs the ops added by the constructor for restoring variables.
1260    It requires a session in which the graph was launched.  The variables to
1261    restore do not have to have been initialized, as restoring is itself a way
1262    to initialize variables.
1263
1264    The `save_path` argument is typically a value previously returned from a
1265    `save()` call, or a call to `latest_checkpoint()`.
1266
1267    Args:
1268      sess: A `Session` to use to restore the parameters. None in eager mode.
1269      save_path: Path where parameters were previously saved.
1270
1271    Raises:
1272      ValueError: If save_path is None or not a valid checkpoint.
1273    """
1274    if self._is_empty:
1275      return
1276    if save_path is None:
1277      raise ValueError("Can't load save_path when it is None.")
1278
1279    checkpoint_prefix = compat.as_text(save_path)
1280    if not checkpoint_management.checkpoint_exists_internal(checkpoint_prefix):
1281      raise ValueError("The passed save_path is not a valid checkpoint: " +
1282                       checkpoint_prefix)
1283
1284    logging.info("Restoring parameters from %s", checkpoint_prefix)
1285    try:
1286      if context.executing_eagerly():
1287        self._build_eager(save_path, build_save=False, build_restore=True)
1288      else:
1289        sess.run(self.saver_def.restore_op_name,
1290                 {self.saver_def.filename_tensor_name: save_path})
1291    except errors.NotFoundError as err:
1292      # There are three common conditions that might cause this error:
1293      # 0. The file is missing. We ignore this case; it is checked above.
1294      # 1. This is an object-based checkpoint trying name-based loading.
1295      # 2. The graph has been altered and a variable or other name is missing.
1296
1297      # 1. The checkpoint would not be loaded successfully as is. Try to parse
1298      # it as an object-based checkpoint.
1299      try:
1300        names_to_keys = object_graph_key_mapping(save_path)
1301      except errors.NotFoundError:
1302        # 2. This is not an object-based checkpoint, which likely means there
1303        # is a graph mismatch. Re-raise the original error with
1304        # a helpful message (b/110263146)
1305        raise _wrap_restore_error_with_msg(
1306            err, "a Variable name or other graph key that is missing")
1307
1308      # This is an object-based checkpoint. We'll print a warning and then do
1309      # the restore.
1310      logging.warning(
1311          "Restoring an object-based checkpoint using a name-based saver. This "
1312          "may be somewhat fragile, and will re-build the Saver. Instead, "
1313          "consider loading object-based checkpoints using "
1314          "tf.train.Checkpoint().")
1315      self._object_restore_saver = saver_from_object_based_checkpoint(
1316          checkpoint_path=save_path,
1317          var_list=self._var_list,
1318          builder=self._builder,
1319          names_to_keys=names_to_keys,
1320          cached_saver=self._object_restore_saver)
1321      self._object_restore_saver.restore(sess=sess, save_path=save_path)
1322    except errors.InvalidArgumentError as err:
1323      # There is a mismatch between the graph and the checkpoint being loaded.
1324      # We add a more reasonable error message here to help users (b/110263146)
1325      raise _wrap_restore_error_with_msg(
1326          err, "a mismatch between the current graph and the graph")
1327
1328  @staticmethod
1329  def _add_collection_def(meta_graph_def, key, export_scope=None):
1330    """Adds a collection to MetaGraphDef protocol buffer.
1331
1332    Args:
1333      meta_graph_def: MetaGraphDef protocol buffer.
1334      key: One of the GraphKeys or user-defined string.
1335      export_scope: Optional `string`. Name scope to remove.
1336    """
1337    meta_graph.add_collection_def(
1338        meta_graph_def, key, export_scope=export_scope)
1339
1340
1341@tf_export(v1=["train.import_meta_graph"])
1342def import_meta_graph(meta_graph_or_file,
1343                      clear_devices=False,
1344                      import_scope=None,
1345                      **kwargs):
1346  """Recreates a Graph saved in a `MetaGraphDef` proto.
1347
1348  This function takes a `MetaGraphDef` protocol buffer as input. If
1349  the argument is a file containing a `MetaGraphDef` protocol buffer,
1350  it constructs a protocol buffer from the file content. The function
1351  then adds all the nodes from the `graph_def` field to the
1352  current graph, recreates all the collections, and returns a saver
1353  constructed from the `saver_def` field.
1354
1355  In combination with `export_meta_graph()`, this function can be used to:
1356
1357  * Serialize a graph along with other Python objects such as `QueueRunner`,
1358    `Variable` into a `MetaGraphDef`.
1359
1360  * Restart training from a saved graph and checkpoints.
1361
1362  * Run inference from a saved graph and checkpoints.
1363
1364  ```Python
1365  ...
1366  # Create a saver.
1367  saver = tf.compat.v1.train.Saver(...variables...)
1368  # Remember the training_op we want to run by adding it to a collection.
1369  tf.compat.v1.add_to_collection('train_op', train_op)
1370  sess = tf.compat.v1.Session()
1371  for step in range(1000000):
1372      sess.run(train_op)
1373      if step % 1000 == 0:
1374          # Saves checkpoint, which by default also exports a meta_graph
1375          # named 'my-model-global_step.meta'.
1376          saver.save(sess, 'my-model', global_step=step)
1377  ```
1378
1379  Later we can continue training from this saved `meta_graph` without building
1380  the model from scratch.
1381
1382  ```Python
1383  with tf.compat.v1.Session() as sess:
1384    new_saver = tf.compat.v1.train.import_meta_graph(
1385        'my-save-dir/my-model-10000.meta')
1386    new_saver.restore(sess, 'my-save-dir/my-model-10000')
1387    # tf.compat.v1.get_collection() returns a list. In this example we only
1388    # want the first one.
1389    train_op = tf.compat.v1.get_collection('train_op')[0]
1390    for step in range(1000000):
1391      sess.run(train_op)
1392  ```
1393
1394  NOTE: Restarting training from saved `meta_graph` only works if the
1395  device assignments have not changed.
1396
1397  Example:
1398  Variables, placeholders, and independent operations can also be stored, as
1399  shown in the following example.
1400
1401  ```Python
1402  # Saving contents and operations.
1403  v1 = tf.compat.v1.placeholder(tf.float32, name="v1")
1404  v2 = tf.compat.v1.placeholder(tf.float32, name="v2")
1405  v3 = tf.math.multiply(v1, v2)
1406  vx = tf.Variable(10.0, name="vx")
1407  v4 = tf.add(v3, vx, name="v4")
1408  saver = tf.compat.v1.train.Saver([vx])
1409  sess = tf.compat.v1.Session()
1410  sess.run(tf.compat.v1.global_variables_initializer())
1411  sess.run(vx.assign(tf.add(vx, vx)))
1412  result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
1413  print(result)
1414  saver.save(sess, "./model_ex1")
1415  ```
1416
1417  Later this model can be restored and contents loaded.
1418
1419  ```Python
1420  # Restoring variables and running operations.
1421  saver = tf.compat.v1.train.import_meta_graph("./model_ex1.meta")
1422  sess = tf.compat.v1.Session()
1423  saver.restore(sess, "./model_ex1")
1424  result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
1425  print(result)
1426  ```
1427
1428  Args:
1429    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
1430      the path) containing a `MetaGraphDef`.
1431    clear_devices: Whether or not to clear the device field for an `Operation`
1432      or `Tensor` during import.
1433    import_scope: Optional `string`. Name scope to add. Only used when
1434      initializing from protocol buffer.
1435    **kwargs: Optional keyed arguments.
1436
1437  Returns:
1438    A saver constructed from `saver_def` in `MetaGraphDef` or None.
1439
1440    A None value is returned if no variables exist in the `MetaGraphDef`
1441    (i.e., there are no variables to restore).
1442
1443  Raises:
1444    RuntimeError: If called with eager execution enabled.
1445
1446  @compatibility(eager)
1447  Exporting/importing meta graphs is not supported. No graph exists when eager
1448  execution is enabled.
1449  @end_compatibility
1450  """  # pylint: disable=g-doc-exception
1451  return _import_meta_graph_with_return_elements(meta_graph_or_file,
1452                                                 clear_devices, import_scope,
1453                                                 **kwargs)[0]
1454
1455
1456def _import_meta_graph_with_return_elements(meta_graph_or_file,
1457                                            clear_devices=False,
1458                                            import_scope=None,
1459                                            return_elements=None,
1460                                            **kwargs):
1461  """Import MetaGraph, and return both a saver and returned elements."""
1462  if context.executing_eagerly():
1463    raise RuntimeError("Exporting/importing meta graphs is not supported when "
1464                       "eager execution is enabled. No graph exists when eager "
1465                       "execution is enabled.")
1466  if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
1467    meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
1468  else:
1469    meta_graph_def = meta_graph_or_file
1470
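  # Import the graph (optionally under `import_scope`), collecting the imported
  # variables and any explicitly requested return elements.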
1471  imported_vars, imported_return_elements = (
1472      meta_graph.import_scoped_meta_graph_with_return_elements(
1473          meta_graph_def,
1474          clear_devices=clear_devices,
1475          import_scope=import_scope,
1476          return_elements=return_elements,
1477          **kwargs))
1478
1479  saver = _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
1480                                                 imported_vars)
1481  return saver, imported_return_elements
1482
1483
1484def _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
1485                                           imported_vars):
1486  """Return a saver for restoring variable values to an imported MetaGraph."""
1487  if meta_graph_def.HasField("saver_def"):
1488    # Infer the scope that is prepended by `import_scoped_meta_graph`.
1489    scope = import_scope
1490    var_names = list(imported_vars.keys())
1491    if var_names:
1492      sample_key = var_names[0]
1493      sample_var = imported_vars[sample_key]
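      # For example, if `sample_key` is "v:0" and the imported variable's name
      # is "new_scope/v:0", the inferred scope is "new_scope/".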
1494      scope = sample_var.name[:-len(sample_key)]
1495
1496    return Saver(saver_def=meta_graph_def.saver_def, name=scope)
1497  else:
1498    if variables._all_saveable_objects(scope=import_scope):  # pylint: disable=protected-access
1499      # Return the default saver instance for all graph variables.
1500      return Saver()
1501    else:
1502      # If no graph variables exist, then a Saver cannot be constructed.
1503      logging.info("Saver not created because there are no variables in the"
1504                   " graph to restore")
1505      return None
1506
1507
1508@tf_export(v1=["train.export_meta_graph"])
1509def export_meta_graph(filename=None,
1510                      meta_info_def=None,
1511                      graph_def=None,
1512                      saver_def=None,
1513                      collection_list=None,
1514                      as_text=False,
1515                      graph=None,
1516                      export_scope=None,
1517                      clear_devices=False,
1518                      clear_extraneous_savers=False,
1519                      strip_default_attrs=False,
1520                      save_debug_info=False,
1521                      **kwargs):
1522  # pylint: disable=line-too-long
1523  """Returns `MetaGraphDef` proto.
1524
1525  Optionally writes it to filename.
1526
1527  This function exports the graph, saver, and collection objects into a
1528  `MetaGraphDef` protocol buffer, with the intention of it being imported
1529  at a later time or location to restart training, run inference, or be
1530  used as a subgraph.
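
  A minimal usage sketch (the variable and filename below are hypothetical):

  ```Python
  with tf.Graph().as_default():
    v = tf.compat.v1.get_variable("v", shape=[3])
    saver = tf.compat.v1.train.Saver([v])
    meta_graph_def = tf.compat.v1.train.export_meta_graph(
        filename="/tmp/my-model.meta", saver_def=saver.saver_def)
  ```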
1531
1532  Args:
1533    filename: Optional filename including the path for writing the generated
1534      `MetaGraphDef` protocol buffer.
1535    meta_info_def: `MetaInfoDef` protocol buffer.
1536    graph_def: `GraphDef` protocol buffer.
1537    saver_def: `SaverDef` protocol buffer.
1538    collection_list: List of string keys to collect.
1539    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
1540    graph: The `Graph` to export. If `None`, use the default graph.
1541    export_scope: Optional `string`. Name scope under which to extract the
1542      subgraph. The scope name will be stripped from the node definitions for
1543      easy import later into new name scopes. If `None`, the whole graph is
1544      exported. graph_def and export_scope cannot both be specified.
1545    clear_devices: Whether or not to clear the device field for an `Operation`
1546      or `Tensor` during export.
1547    clear_extraneous_savers: Remove any Saver-related information from the graph
1548      (both Save/Restore ops and SaverDefs) that are not associated with the
1549      provided SaverDef.
1550    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
1551      removed from the NodeDefs. For a detailed guide, see
1552      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
1553    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
1554      which is in the same directory as `filename` and has `_debug` added before
1555      the file extension.
1556    **kwargs: Optional keyed arguments.
1557
1558  Returns:
1559    A `MetaGraphDef` proto.
1560
1561  Raises:
1562    ValueError: When the `GraphDef` is larger than 2GB.
1563    RuntimeError: If called with eager execution enabled.
1564
1565  @compatibility(eager)
1566  Exporting/importing meta graphs is not supported unless both `graph_def` and
1567  `graph` are provided. No graph exists when eager execution is enabled.
1568  @end_compatibility
1569  """
1570  # pylint: enable=line-too-long
1571  if context.executing_eagerly() and not (graph_def is not None and
1572                                          graph is not None):
1573    raise RuntimeError("Exporting/importing meta graphs is not supported when "
1574                       "eager execution is enabled. No graph exists when eager "
1575                       "execution is enabled.")
1576  meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
1577      filename=filename,
1578      meta_info_def=meta_info_def,
1579      graph_def=graph_def,
1580      saver_def=saver_def,
1581      collection_list=collection_list,
1582      as_text=as_text,
1583      graph=graph,
1584      export_scope=export_scope,
1585      clear_devices=clear_devices,
1586      clear_extraneous_savers=clear_extraneous_savers,
1587      strip_default_attrs=strip_default_attrs,
1588      save_debug_info=save_debug_info,
1589      **kwargs)
1590  return meta_graph_def
1591
1592
1593def _wrap_restore_error_with_msg(err, extra_verbiage):
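  """Returns a copy of `err` with a more helpful restore error message."""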
1594  err_msg = ("Restoring from checkpoint failed. This is most likely "
1595             "due to {} from the checkpoint. Please ensure that you "
1596             "have not altered the graph expected based on the checkpoint. "
1597             "Original error:\n\n{}").format(extra_verbiage, err.message)
1598  return err.__class__(err.node_def, err.op, err_msg)
1599
1600
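# Registering `to_proto` / `from_proto` lets `Saver` objects stored in the
# SAVERS collection round-trip through `MetaGraphDef` collection defs.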
1601ops.register_proto_function(
1602    ops.GraphKeys.SAVERS,
1603    proto_type=saver_pb2.SaverDef,
1604    to_proto=Saver.to_proto,
1605    from_proto=Saver.from_proto)
1606
1607
1608def object_graph_key_mapping(checkpoint_path):
1609  """Return name to key mappings from the checkpoint.
1610
1611  Args:
1612    checkpoint_path: string, path to object-based checkpoint
1613
1614  Returns:
1615    Dictionary mapping tensor names to checkpoint keys.
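
  A minimal usage sketch (the checkpoint path below is hypothetical):

  ```Python
  names_to_keys = object_graph_key_mapping("/tmp/ckpt/model-1")
  # A hypothetical result for a checkpoint written with `tf.train.Checkpoint`:
  # {"dense/kernel": "dense/kernel/.ATTRIBUTES/VARIABLE_VALUE"}
  ```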
1616  """
1617  reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
1618  object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
1619  object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
1620  object_graph_proto.ParseFromString(object_graph_string)
1621  names_to_keys = {}
1622  for node in object_graph_proto.nodes:
1623    for attribute in node.attributes:
1624      names_to_keys[attribute.full_name] = attribute.checkpoint_key
1625  return names_to_keys
1626
1627
1628def saver_from_object_based_checkpoint(checkpoint_path,
1629                                       var_list=None,
1630                                       builder=None,
1631                                       names_to_keys=None,
1632                                       cached_saver=None):
1633  """Return a `Saver` which reads from an object-based checkpoint.
1634
1635  This function validates that all variables in the variables list are remapped
1636  in the object-based checkpoint (or `names_to_keys` dict if provided). A
1637  saver will be created with the list of remapped variables.
1638
1639  The `cached_saver` argument allows the user to pass in a previously created
1640  saver, so that multiple `saver.restore()` calls do not pollute the graph when
1641  building the graph. This assumes that keys are consistent, meaning that both
1642    1) the checkpoint at `checkpoint_path`, and
1643    2) the checkpoint used to create `cached_saver`
1644  are the same type of object-based checkpoint. If this argument is set, this
1645  function simply validates that all variables have been remapped by the
1646  checkpoint at `checkpoint_path`.
1647
1648  Note that in general, `tf.train.Checkpoint` should be used to restore/save an
1649  object-based checkpoint.
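
  A minimal usage sketch (the checkpoint path below is hypothetical):

  ```Python
  saver = saver_from_object_based_checkpoint("/tmp/ckpt/model-1")
  with tf.compat.v1.Session() as sess:
    saver.restore(sess, "/tmp/ckpt/model-1")
  ```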
1650
1651  Args:
1652    checkpoint_path: string, path to object-based checkpoint
1653    var_list: list of `Variables` that appear in the checkpoint. If `None`,
1654      `var_list` will be set to all saveable objects.
1655    builder: a `BaseSaverBuilder` instance. If `None`, a new `BulkSaverBuilder`
1656      will be created.
1657    names_to_keys: dict mapping string tensor names to checkpoint keys. If
1658      `None`, this dict will be generated from the checkpoint file.
1659    cached_saver: Cached `Saver` object with remapped variables.
1660
1661  Returns:
1662    `Saver` with remapped variables for reading from an object-based checkpoint.
1663
1664  Raises:
1665    ValueError: If the checkpoint provided is not an object-based checkpoint.
1666    NotFoundError: If one of the variables in `var_list` cannot be found in the
1667      checkpoint. This could mean the checkpoint or `names_to_keys` mapping is
1668      missing the variable.
1669  """
1670  if names_to_keys is None:
1671    try:
1672      names_to_keys = object_graph_key_mapping(checkpoint_path)
1673    except errors.NotFoundError:
1674      raise ValueError("Checkpoint in %s not an object-based checkpoint." %
1675                       checkpoint_path)
1676  if var_list is None:
1677    var_list = variables._all_saveable_objects()  # pylint: disable=protected-access
1678  if builder is None:
1679    builder = BulkSaverBuilder()
1680
1681  saveables = saveable_object_util.validate_and_slice_inputs(var_list)
1682  current_names = set()
1683  for saveable in saveables:
1684    for spec in saveable.specs:
1685      current_names.add(spec.name)
1686  previous_names = set(names_to_keys.keys())
1687  missing_names = current_names - previous_names
1688  if missing_names:
1689    extra_names = previous_names - current_names
1690    intersecting_names = previous_names.intersection(current_names)
1691    raise errors.NotFoundError(
1692        None,
1693        None,
1694        message=(
1695            "\n\nExisting variables not in the checkpoint: %s\n\n"
1696            "Variable names when this checkpoint was written which don't "
1697            "exist now: %s\n\n"
1698            "(%d variable name(s) did match)\n\n"
1699            "Could not find some variables in the checkpoint (see names "
1700            "above). Saver was attempting to load an object-based checkpoint "
1701            "(saved using tf.train.Checkpoint or tf.keras.Model.save_weights) "
1702            "using variable names. If the checkpoint was written with eager "
1703            "execution enabled, it's possible that variable names have "
1704            "changed (for example missing a '_1' suffix). It's also "
1705            "possible that there are new variables which did not exist "
1706            "when the checkpoint was written. You can construct a "
1707            "Saver(var_list=...) with only the variables which previously "
1708            "existed, and if variable names have changed you may need to "
1709            "make this a dictionary with the old names as keys. If you're "
1710            "using an Estimator, you'll need to return a tf.train.Saver "
1711            "inside a tf.train.Scaffold from your model_fn.") %
1712        (", ".join(sorted(missing_names)), ", ".join(
1713            sorted(extra_names)), len(intersecting_names)))
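  # Point each SaveSpec at the key the object-based checkpoint actually used,
  # so the restore ops read the correct tensors from the checkpoint file.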
1714  for saveable in saveables:
1715    for spec in saveable.specs:
1716      spec.name = names_to_keys[spec.name]
1717  if cached_saver is None:
1718    return Saver(saveables)
1719  return cached_saver
1720