# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=invalid-name
"""Save and restore variables.

Symbols in this file are deprecated. See replacements in
tensorflow/python/training/trackable and tensorflow/python/training/saving.
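
For new code, prefer the object-based `tf.train.Checkpoint` API. A minimal
sketch of the replacement workflow (assumes existing `model` and `optimizer`
objects):

```python
import tensorflow as tf

# `model` and `optimizer` are assumed to be existing trackable objects.
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(ckpt, directory="/tmp/ckpts", max_to_keep=5)
manager.save()                            # writes a numbered checkpoint
ckpt.restore(manager.latest_checkpoint)   # reloads the most recent one
```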
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os.path
import time

import numpy as np
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import training_util
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export

# TODO(allenl): Remove these aliases once all users are migrated off.
get_checkpoint_state = checkpoint_management.get_checkpoint_state
update_checkpoint_state = checkpoint_management.update_checkpoint_state
generate_checkpoint_state_proto = (
    checkpoint_management.generate_checkpoint_state_proto)
latest_checkpoint = checkpoint_management.latest_checkpoint
checkpoint_exists = checkpoint_management.checkpoint_exists
get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes
remove_checkpoint = checkpoint_management.remove_checkpoint


class BaseSaverBuilder(object):
  """Base class for Savers.

  Can be extended to create different Ops.
  """

  SaveSpec = saveable_object.SaveSpec
  SaveableObject = saveable_object.SaveableObject

  # Aliases for code which was moved but still has lots of users.
  VariableSaveable = saveable_object_util.ReferenceVariableSaveable
  ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable

  def __init__(self, write_version=saver_pb2.SaverDef.V2):
    self._write_version = write_version

  def save_op(self, filename_tensor, saveables):
    """Create an Op to save 'saveables'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor.
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      An Operation that saves the variables.

    Raises:
      RuntimeError: (implementation detail) if "self._write_version" is an
        unexpected value.
    """
    # pylint: disable=protected-access
    tensor_names = []
    tensors = []
    tensor_slices = []
    for saveable in saveables:
      for spec in saveable.specs:
        tensor_names.append(spec.name)
        tensors.append(spec.tensor)
        tensor_slices.append(spec.slice_spec)
    if self._write_version == saver_pb2.SaverDef.V1:
      return io_ops._save(
          filename=filename_tensor,
          tensor_names=tensor_names,
          tensors=tensors,
          tensor_slices=tensor_slices)
    elif self._write_version == saver_pb2.SaverDef.V2:
      # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
      # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
      return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                            tensors)
    else:
      raise RuntimeError("Unexpected write_version: " +
                         str(self._write_version))

  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    """Restore all tensors contained in saveables.

    By default, this issues separate calls to `restore_op` for each saveable.
    Subclasses may override to load multiple saveables in a single call.

    Args:
      filename_tensor: String Tensor.
      saveables: List of BaseSaverBuilder.SaveableObject objects.
      preferred_shard: Int.  Shard to open first when loading a sharded file.
      restore_sequentially: Unused.  Bool.  If true, each restore is sequential.

    Returns:
      A list of Tensors resulting from reading the saveables from
        'filename'.

    """
    del restore_sequentially
    all_tensors = []
    for saveable in saveables:
      if saveable.device:
        device = saveable_object_util.set_cpu0(saveable.device)
      else:
        device = None
      with ops.device(device):
        all_tensors.extend(
            self.restore_op(filename_tensor, saveable, preferred_shard))
    return all_tensors

  # pylint: disable=unused-argument
  def restore_op(self, filename_tensor, saveable, preferred_shard):
    """Create ops to restore 'saveable'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor.
      saveable: A BaseSaverBuilder.SaveableObject object.
      preferred_shard: Int.  Shard to open first when loading a sharded file.

    Returns:
      A list of Tensors resulting from reading 'saveable' from
        'filename'.
    """
    # pylint: disable=protected-access
    tensors = []
    for spec in saveable.specs:
      tensors.append(
          io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec],
                            [spec.dtype])[0])

    return tensors

  # pylint: enable=unused-argument

  def sharded_filename(self, filename_tensor, shard, num_shards):
    """Append sharding information to a filename.

    Args:
      filename_tensor: A string tensor.
      shard: Integer.  The shard for the filename.
      num_shards: An int Tensor for the number of shards.

    Returns:
      A string tensor.
    """
    return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)

  def _AddSaveOps(self, filename_tensor, saveables):
    """Add ops to save variables that are on the same shard.

    Args:
      filename_tensor: String Tensor.
      saveables: A list of SaveableObject objects.

    Returns:
      A tensor with the filename used to save.
    """
    save = self.save_op(filename_tensor, saveables)
    return control_flow_ops.with_dependencies([save], filename_tensor)

  def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device):
    """Add ops to save the params per shard, for the V2 format.

    Note that the sharded save procedure for the V2 format is different from
    V1: there is a special "merge" step that merges the small metadata produced
    from each device.

    Args:
      checkpoint_prefix: scalar String Tensor.  Interpreted *NOT AS A FILENAME*,
        but as a prefix of a V2 checkpoint;
      per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
        returned by _GroupByDevices().

    Returns:
      An op to save the variables, which, when evaluated, returns the prefix
        "<user-fed prefix>" only and does not include the sharded spec suffix.
    """
    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>.
    # * If checkpoint_prefix is a S3 bucket path ".part" is appended to it
    # * Otherwise _temp/part is appended which is normalized relative to the OS
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    #   Filesystems with eventual consistency (such as S3) don't need a
    #   temporary location. Using a temporary directory in those cases might
    #   cause situations where files are not available during the copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
    # prefix directly, instead of any physical pathname.  (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      _SHARDED_SUFFIX = array_ops.where(
          string_ops.regex_full_match(checkpoint_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant(os.path.normpath("_temp/part")))
      tmp_checkpoint_prefix = string_ops.string_join(
          [checkpoint_prefix, _SHARDED_SUFFIX])

    num_shards = len(per_device)
    sharded_saves = []
    sharded_prefixes = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    last_device = None
    for shard, (device, saveables) in enumerate(per_device):
      last_device = device
      with ops.device(saveable_object_util.set_cpu0(device)):
        sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
                                                 num_shards_tensor)
        sharded_prefixes.append(sharded_filename)
        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))

    with ops.control_dependencies([x.op for x in sharded_saves]):
      # Co-locates the merge step with the last device.
      with ops.device(saveable_object_util.set_cpu0(last_device)):
        # V2 format write path consists of a metadata merge step.  Once merged,
        # attempts to delete the temporary directory, "<user-fed prefix>_temp".
        merge_step = gen_io_ops.merge_v2_checkpoints(
            sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
        with ops.control_dependencies([merge_step]):
          # Returns the prefix "<user-fed prefix>" only.  DOES NOT include the
          # sharded spec suffix.
          return array_ops.identity(checkpoint_prefix)

  def _AddShardedSaveOps(self, filename_tensor, per_device):
    """Add ops to save the params per shard.

    Args:
      filename_tensor: a scalar String Tensor.
      per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
        returned by _GroupByDevices().

    Returns:
      An op to save the variables.
    """
    if self._write_version == saver_pb2.SaverDef.V2:
      return self._AddShardedSaveOpsForV2(filename_tensor, per_device)

    num_shards = len(per_device)
    sharded_saves = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    for shard, (device, saveables) in enumerate(per_device):
      with ops.device(device):
        sharded_filename = self.sharded_filename(filename_tensor, shard,
                                                 num_shards_tensor)
        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
    # Return the sharded name for the save path.
    with ops.control_dependencies([x.op for x in sharded_saves]):
      return gen_io_ops.sharded_filespec(filename_tensor, num_shards_tensor)

  def _AddRestoreOps(self,
                     filename_tensor,
                     saveables,
                     restore_sequentially,
                     reshape,
                     preferred_shard=-1,
                     name="restore_all"):
    """Add operations to restore saveables.

    Args:
      filename_tensor: Tensor for the path of the file to load.
      saveables: A list of SaveableObject objects.
      restore_sequentially: True if we want to restore variables sequentially
        within a shard.
      reshape: True if we want to reshape loaded tensors to the shape of the
        corresponding variable.
      preferred_shard: Shard to open first when loading a sharded file.
      name: Name for the returned op.

    Returns:
      An Operation that restores the variables.
    """
    all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
                                    restore_sequentially)

    assign_ops = []
    idx = 0
    # Load and optionally reshape on the CPU, as string tensors are not
    # available on the GPU.
    # TODO(touts): Re-enable restore on GPU when we can support annotating
    # string tensors as "HostMemory" inputs.
    for saveable in saveables:
      shapes = None
      if reshape:
        # Compute the shapes, let the restore op decide if and how to do
        # the reshape.
        shapes = []
        for spec in saveable.specs:
          v = spec.tensor
          shape = v.get_shape()
          if not shape.is_fully_defined():
            shape = array_ops.shape(v)
          shapes.append(shape)
      saveable_tensors = all_tensors[idx:idx + len(saveable.specs)]
      idx += len(saveable.specs)
      assign_ops.append(saveable.restore(saveable_tensors, shapes))

    # Create a Noop that has control dependencies from all the updates.
    return control_flow_ops.group(*assign_ops, name=name)

  def _AddShardedRestoreOps(self, filename_tensor, per_device,
                            restore_sequentially, reshape):
    """Add Ops to restore variables from multiple devices.

    Args:
      filename_tensor: Tensor for the path of the file to load.
      per_device: A list of (device, SaveableObject) pairs, as returned by
        _GroupByDevices().
      restore_sequentially: True if we want to restore variables sequentially
        within a shard.
      reshape: True if we want to reshape loaded tensors to the shape of the
        corresponding variable.

    Returns:
      An Operation that restores the variables.
    """
    sharded_restores = []
    for shard, (device, saveables) in enumerate(per_device):
      with ops.device(device):
        sharded_restores.append(
            self._AddRestoreOps(
                filename_tensor,
                saveables,
                restore_sequentially,
                reshape,
                preferred_shard=shard,
                name="restore_shard"))
    return control_flow_ops.group(*sharded_restores, name="restore_all")

  def _GroupByDevices(self, saveables):
    """Group Variable tensor slices per device.

    TODO(touts): Make sure that all the devices found are on different
    job/replica/task/cpu|gpu.  It would be bad if 2 were on the same device.
    It can happen if the devices are unspecified.

    Args:
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      A list of (device_name, BaseSaverBuilder.SaveableObject) tuples.
      The list is sorted by ascending device_name.

    Raises:
      ValueError: If the tensors of a saveable are on different devices.
    """
    per_device = collections.defaultdict(lambda: [])
    for saveable in saveables:
      canonical_device = set(
          pydev.canonical_name(spec.device) for spec in saveable.specs)
      if len(canonical_device) != 1:
        raise ValueError("All tensors of a saveable object must be "
                         "on the same device: %s" % saveable.name)
      per_device[canonical_device.pop()].append(saveable)
    return sorted(per_device.items(), key=lambda t: t[0])

  def build(self,
            names_to_saveables,
            reshape=False,
            sharded=False,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=10000.0,
            name=None,
            restore_sequentially=False,
            filename="model"):
    """Builds save/restore graph nodes or runs save/restore in eager mode.

    Args:
      names_to_saveables: A dictionary mapping name to a Variable or
        SaveableObject. Each name will be associated with the corresponding
        variable in the checkpoint.
      reshape: If True, allow restoring parameters from a checkpoint where the
        parameters have a different shape.  This is only needed when you try to
        restore from a Dist-Belief checkpoint, and only sometimes.
      sharded: If True, shard the checkpoints, one per device that has Variable
        nodes.
      max_to_keep: Maximum number of checkpoints to keep.  As new checkpoints
        are created, old ones are deleted.  If None or 0, no checkpoints are
        deleted from the filesystem but only the last one is kept in the
        `checkpoint` file.  Presently the number is only roughly enforced.  For
        example in case of restarts more than max_to_keep checkpoints may be
        kept.
      keep_checkpoint_every_n_hours: How often checkpoints should be kept.
        Defaults to 10,000 hours.
      name: String.  Optional name to use as a prefix when adding operations.
      restore_sequentially: A Bool, which if true, causes restore of different
        variables to happen sequentially within each device.
      filename: If known at graph construction time, filename used for variable
        loading/saving. If None, then the default name "model" will be used.

    Returns:
      A SaverDef proto.

    Raises:
      TypeError: If 'names_to_saveables' is not a dictionary mapping string
        keys to variable Tensors.
      ValueError: If any of the keys or values in 'names_to_saveables' is not
        unique.
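
    A minimal sketch of calling `build()` directly (illustrative; assumes an
    existing `tf.compat.v1.Variable` named `v`):

    ```python
    builder = BulkSaverBuilder()
    # Maps the checkpoint key "v" to the variable and returns a SaverDef proto.
    saver_def = builder.build({"v": v}, sharded=False, max_to_keep=5)
    ```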
    """
    return self._build_internal(
        names_to_saveables=names_to_saveables,
        reshape=reshape,
        sharded=sharded,
        max_to_keep=max_to_keep,
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
        name=name,
        restore_sequentially=restore_sequentially,
        filename=filename)

  def _build_internal(self,
                      names_to_saveables,
                      reshape=False,
                      sharded=False,
                      max_to_keep=5,
                      keep_checkpoint_every_n_hours=10000.0,
                      name=None,
                      restore_sequentially=False,
                      filename="model",
                      build_save=True,
                      build_restore=True):
    """build() with option to only perform save and restore."""
    if not context.executing_eagerly() and (not build_save or
                                            not build_restore):
      raise ValueError("save and restore operations need to be built together "
                       "when eager execution is not enabled.")

    saveables = saveable_object_util.validate_and_slice_inputs(
        names_to_saveables)
    if max_to_keep is None:
      max_to_keep = 0

    with ops.name_scope(name, "save",
                        [saveable.op for saveable in saveables]) as name:
      # Add a placeholder string tensor for the filename.
      filename_tensor = array_ops.placeholder_with_default(
          filename or "model", shape=(), name="filename")
      # Keep the name "Const" for backwards compatibility.
      filename_tensor = array_ops.placeholder_with_default(
          filename_tensor, shape=(), name="Const")

      # Add the save ops.
      if sharded:
        per_device = self._GroupByDevices(saveables)
        if build_save:
          save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
        if build_restore:
          restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
                                                  restore_sequentially, reshape)
      else:
        if build_save:
          save_tensor = self._AddSaveOps(filename_tensor, saveables)
        if build_restore:
          restore_op = self._AddRestoreOps(filename_tensor, saveables,
                                           restore_sequentially, reshape)

    # In the following use case, it's possible to have restore_ops be called
    # something else:
    # - Build inference graph and export a meta_graph.
    # - Import the inference meta_graph
    # - Extend the inference graph to a train graph.
    # - Export a new meta_graph.
    # Now the second restore_op will be called "restore_all_1".
    # As such, comment out the assert for now until we know whether supporting
    # such usage model makes sense.
    #
    # assert restore_op.name.endswith("restore_all"), restore_op.name
    if context.executing_eagerly():
      # Store the tensor values to the tensor_names.
      save_tensor_name = save_tensor.numpy() if build_save else ""
      return saver_pb2.SaverDef(
          filename_tensor_name=filename_tensor.numpy(),
          save_tensor_name=save_tensor_name,
          restore_op_name="",
          max_to_keep=max_to_keep,
          sharded=sharded,
          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
          version=self._write_version)
    else:
      graph = ops.get_default_graph()
      # Do some sanity checking on collections containing
      # PartitionedVariables. If a saved collection has a PartitionedVariable,
      # the GraphDef needs to include concat ops to get the value (or there'll
      # be a lookup error on load).
      check_collection_list = graph.get_all_collection_keys()
      for collection_type in check_collection_list:
        for element in graph.get_collection(collection_type):
          if isinstance(element, variables.PartitionedVariable):
            try:
              graph.get_operation_by_name(element.name)
            except KeyError:
              # Create a concat op for this PartitionedVariable. The user may
              # not need it, but we'll try looking it up on MetaGraph restore
              # since it's in a collection.
              element.as_tensor()
      return saver_pb2.SaverDef(
          filename_tensor_name=filename_tensor.name,
          save_tensor_name=save_tensor.name,
          restore_op_name=restore_op.name,
          max_to_keep=max_to_keep,
          sharded=sharded,
          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
          version=self._write_version)


class BulkSaverBuilder(BaseSaverBuilder):
  """SaverBuilder with support for bulk restoring multiple saveables."""

  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):

    # Ignored: bulk restore is internally sequential.
    del restore_sequentially
    restore_specs = []
    for saveable in saveables:
      for spec in saveable.specs:
        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))

    names, slices, dtypes = zip(*restore_specs)
    # Load all tensors onto CPU 0 for compatibility with existing code.
    with ops.device("cpu:0"):
      return io_ops.restore_v2(filename_tensor, names, slices, dtypes)


def _get_saver_or_default():
  """Returns the saver from SAVERS collection, or creates a default one.

  This method is used by other members of the training module, such as
  `Scaffold`, or `CheckpointSaverHook`.

  Returns:
    `Saver`.

  Raises:
    RuntimeError: If the SAVERS collection already has more than one item.
  """
  collection_key = ops.GraphKeys.SAVERS
  savers = ops.get_collection(collection_key)
  if savers:
    if len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))
    return savers[0]
  saver = Saver(sharded=True, allow_empty=True)
  if saver is not None:
    ops.add_to_collection(collection_key, saver)
  return saver


@tf_export(v1=["train.Saver"])
class Saver(object):
  """Saves and restores variables.

  See [Variables](https://tensorflow.org/guide/variables)
  for an overview of variables, saving and restoring.

  The `Saver` class adds ops to save and restore variables to and from
  *checkpoints*.  It also provides convenience methods to run these ops.

  Checkpoints are binary files in a proprietary format which map variable names
  to tensor values.  The best way to examine the contents of a checkpoint is to
  load it using a `Saver`.

  Savers can automatically number checkpoint filenames with a provided counter.
  This lets you keep multiple checkpoints at different steps while training a
  model.  For example you can number the checkpoint filenames with the training
  step number.  To avoid filling up disks, savers manage checkpoint files
  automatically. For example, they can keep only the N most recent files, or
  one checkpoint for every N hours of training.

  You number checkpoint filenames by passing a value to the optional
  `global_step` argument to `save()`:

  ```python
  saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
  ...
  saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
  ```

  Additionally, optional arguments to the `Saver()` constructor let you control
  the proliferation of checkpoint files on disk:

  * `max_to_keep` indicates the maximum number of recent checkpoint files to
    keep.  As new files are created, older files are deleted.  If None or 0,
    no checkpoints are deleted from the filesystem but only the last one is
    kept in the `checkpoint` file.  Defaults to 5 (that is, the 5 most recent
    checkpoint files are kept.)

  * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent
    `max_to_keep` checkpoint files, you might want to keep one checkpoint file
    for every N hours of training.  This can be useful if you want to later
    analyze how a model progressed during a long training session.  For
    example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep
    one checkpoint file for every 2 hours of training.  The default value of
    10,000 hours effectively disables the feature (see the example below).
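
  For example, the following constructor call (an illustrative sketch) keeps
  the 5 most recent checkpoint files plus one checkpoint for every 2 hours of
  training:

  ```python
  saver = tf.compat.v1.train.Saver(
      ...variables...,
      max_to_keep=5,
      keep_checkpoint_every_n_hours=2)
  ```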

  Note that you still have to call the `save()` method to save the model.
  Passing these arguments to the constructor will not save variables
  automatically for you.

  A training program that saves regularly looks like:

  ```python
  ...
  # Create a saver.
  saver = tf.compat.v1.train.Saver(...variables...)
  # Launch the graph and train, saving the model every 1,000 steps.
  sess = tf.compat.v1.Session()
  for step in xrange(1000000):
      sess.run(..training_op..)
      if step % 1000 == 0:
          # Append the step number to the checkpoint name:
          saver.save(sess, 'my-model', global_step=step)
  ```

  In addition to checkpoint files, savers keep a protocol buffer on disk with
  the list of recent checkpoints. This is used to manage numbered checkpoint
  files and by `latest_checkpoint()`, which makes it easy to discover the path
  to the most recent checkpoint. That protocol buffer is stored in a file named
  'checkpoint' next to the checkpoint files.
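
  For example, to reload the most recently written checkpoint (an illustrative
  sketch; `train_dir` is assumed to be the directory the checkpoints were
  saved into):

  ```python
  ckpt_path = tf.train.latest_checkpoint(train_dir)
  if ckpt_path:  # None if no checkpoint state file was found.
    saver.restore(sess, ckpt_path)
  ```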

  If you create several savers, you can specify a different filename for the
  protocol buffer file in the call to `save()`.
  """

  def __init__(self,
               var_list=None,
               reshape=False,
               sharded=False,
               max_to_keep=5,
               keep_checkpoint_every_n_hours=10000.0,
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False,
               filename=None):
    """Creates a `Saver`.

    The constructor adds ops to save and restore variables.

    `var_list` specifies the variables that will be saved and restored. It can
    be passed as a `dict` or a list:

    * A `dict` of names to variables: The keys are the names that will be
      used to save or restore the variables in the checkpoint files.
    * A list of variables: The variables will be keyed with their op name in
      the checkpoint files.

    For example:

    ```python
    v1 = tf.Variable(..., name='v1')
    v2 = tf.Variable(..., name='v2')

    # Pass the variables as a dict:
    saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2})

    # Or pass them as a list.
    saver = tf.compat.v1.train.Saver([v1, v2])
    # Passing a list is equivalent to passing a dict with the variable op names
    # as keys:
    saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]})
    ```

    Note: the newer `AutoTrackable` API is not supported by `Saver`. Use the
    `tf.train.Checkpoint` class instead.

    The optional `reshape` argument, if `True`, allows restoring a variable from
    a save file where the variable had a different shape, but the same number
    of elements and type.  This is useful if you have reshaped a variable and
    want to reload it from an older checkpoint.

    The optional `sharded` argument, if `True`, instructs the saver to shard
    checkpoints per device.

    Args:
      var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
        names to `SaveableObject`s. If `None`, defaults to the list of all
        saveable objects.
      reshape: If `True`, allows restoring parameters from a checkpoint where
        the variables have a different shape.
      sharded: If `True`, shard the checkpoints, one per device.
      max_to_keep: Maximum number of recent checkpoints to keep. Defaults to 5.
      keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to
        10,000 hours.
      name: String.  Optional name to use as a prefix when adding operations.
      restore_sequentially: A `Bool`, which if true, causes restore of different
        variables to happen sequentially within each device.  This can lower
        memory usage when restoring very large models.
      saver_def: Optional `SaverDef` proto to use instead of running the
        builder. This is only useful for specialty code that wants to recreate a
        `Saver` object for a previously built `Graph` that had a `Saver`. The
        `saver_def` proto should be the one returned by the `as_saver_def()`
        call of the `Saver` that was created for that `Graph`.
      builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
        Defaults to `BulkSaverBuilder()`.
      defer_build: If `True`, defer adding the save and restore ops to the
        `build()` call. In that case `build()` should be called before
        finalizing the graph or using the saver.
      allow_empty: If `False` (default) raise an error if there are no variables
        in the graph. Otherwise, construct the saver anyway and make it a no-op.
      write_version: controls what format to use when saving checkpoints.  It
        also affects certain filepath matching logic.  The V2 format is the
        recommended choice: it is much more optimized than V1 in terms of memory
        required and latency incurred during restore.  Regardless of this flag,
        the Saver is able to restore from both V2 and V1 checkpoints.
      pad_step_number: if True, pads the global step number in the checkpoint
        filepaths to some fixed width (8 by default).  This is turned off by
        default.
      save_relative_paths: If `True`, will write relative paths to the
        checkpoint state file. This is needed if the user wants to copy the
        checkpoint directory and reload from the copied directory.
      filename: If known at graph construction time, filename used for variable
        loading/saving.

    Raises:
      TypeError: If `var_list` is invalid.
      ValueError: If any of the keys or values in `var_list` are not unique.
      RuntimeError: If eager execution is enabled and `var_list` does not specify
        a list of variables to save.

    @compatibility(eager)
    When eager execution is enabled, `var_list` must specify a `list` or `dict`
    of variables to save. Otherwise, a `RuntimeError` will be raised.

    Although Saver works in some cases when executing eagerly, it is
    fragile. Please switch to `tf.train.Checkpoint` or
    `tf.keras.Model.save_weights`, which perform a more robust object-based
    saving. These APIs will load checkpoints written by `Saver`.
    @end_compatibility
    """
    if defer_build and var_list:
      raise ValueError(
          "If `var_list` is provided then build cannot be deferred. "
          "Either set defer_build=False or var_list=None.")
    if context.executing_eagerly():
      logging.warning(
          "Saver is deprecated, please switch to tf.train.Checkpoint or "
          "tf.keras.Model.save_weights for training checkpoints. When "
          "executing eagerly variables do not necessarily have unique names, "
          "and so the variable.name-based lookups Saver performs are "
          "error-prone.")
      if var_list is None:
        raise RuntimeError(
            "When eager execution is enabled, `var_list` must specify a list "
            "or dict of variables to save")
    self._var_list = var_list
    self._reshape = reshape
    self._sharded = sharded
    self._max_to_keep = max_to_keep
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    self._name = name
    self._restore_sequentially = restore_sequentially
    self.saver_def = saver_def
    self._builder = builder
    self._is_built = False
    self._allow_empty = allow_empty
    self._is_empty = None
    self._write_version = write_version
    self._pad_step_number = pad_step_number
    self._filename = filename
    self._last_checkpoints = []
    self._checkpoints_to_be_deleted = []
    if context.executing_eagerly():
      self._next_checkpoint_time = (
          time.time() + self._keep_checkpoint_every_n_hours * 3600)
    elif not defer_build:
      self.build()
    if self.saver_def:
      self._check_saver_def()
      self._write_version = self.saver_def.version
    self._save_relative_paths = save_relative_paths
    # For compatibility with object-based checkpoints, we may build a second
    # Saver to read the renamed keys.
    self._object_restore_saver = None

  def build(self):
    if context.executing_eagerly():
      raise RuntimeError("Use save/restore instead of build in eager mode.")
    self._build(self._filename, build_save=True, build_restore=True)

  def _build_eager(self, checkpoint_path, build_save, build_restore):
    self._build(
        checkpoint_path, build_save=build_save, build_restore=build_restore)

  def _build(self, checkpoint_path, build_save, build_restore):
    """Builds saver_def."""
    if not context.executing_eagerly():
      if self._is_built:
        return
      self._is_built = True

    if not self.saver_def or context.executing_eagerly():
      if self._builder is None:
        self._builder = BulkSaverBuilder(self._write_version)

      if self._var_list is None:
        # pylint: disable=protected-access
        self._var_list = variables._all_saveable_objects()
      if not self._var_list:
        if self._allow_empty:
          self._is_empty = True
          return
        else:
          raise ValueError("No variables to save")
      self._is_empty = False

      self.saver_def = self._builder._build_internal(  # pylint: disable=protected-access
          self._var_list,
          reshape=self._reshape,
          sharded=self._sharded,
          max_to_keep=self._max_to_keep,
          keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
          name=self._name,
          restore_sequentially=self._restore_sequentially,
          filename=checkpoint_path,
          build_save=build_save,
          build_restore=build_restore)
    elif self.saver_def and self._name:
      # Since self._name is used as a name_scope by builder(), we are
      # overloading the use of this field to represent the "import_scope" as
      # well.
      self.saver_def.filename_tensor_name = ops.prepend_name_scope(
          self.saver_def.filename_tensor_name, self._name)
      self.saver_def.save_tensor_name = ops.prepend_name_scope(
          self.saver_def.save_tensor_name, self._name)
      self.saver_def.restore_op_name = ops.prepend_name_scope(
          self.saver_def.restore_op_name, self._name)

    self._check_saver_def()
    if not context.executing_eagerly():
      # Updates next checkpoint time.
      # Set in __init__ when executing eagerly.
      self._next_checkpoint_time = (
          time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)

  def _check_saver_def(self):
    if not isinstance(self.saver_def, saver_pb2.SaverDef):
      raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
                       self.saver_def)
    if not context.executing_eagerly():
      if not self.saver_def.save_tensor_name:
        raise ValueError("saver_def must specify the save_tensor_name: %s" %
                         str(self.saver_def))
      if not self.saver_def.restore_op_name:
        raise ValueError("saver_def must specify the restore_op_name: %s" %
                         str(self.saver_def))

  def _CheckpointFilename(self, p):
    """Returns the checkpoint filename given a `(filename, time)` pair.

    Args:
      p: (filename, time) pair.

    Returns:
      Checkpoint file name.
    """
    name, _ = p
    return name

  def _RecordLastCheckpoint(self, latest_save_path):
    """Manages the list of the latest checkpoints."""
    if not self.saver_def.max_to_keep:
      return
    # Remove first from list if the same name was used before.
    for p in self._last_checkpoints:
      if latest_save_path == self._CheckpointFilename(p):
        self._last_checkpoints.remove(p)
    # Append new path to list
    self._last_checkpoints.append((latest_save_path, time.time()))

    # If more than max_to_keep, remove oldest.
    if len(self._last_checkpoints) > self.saver_def.max_to_keep:
      self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0))

  def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"):
    """Deletes old checkpoints if necessary.

    `self._checkpoints_to_be_deleted` is going to contain checkpoints that are
    over `max_to_keep`.  They are going to be deleted.  If
    `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
    every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
    kept for every 0.5 hours of training; if `N` is 10, an additional
    checkpoint is kept for every 10 hours of training.

    Args:
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
    """
    if self._checkpoints_to_be_deleted:
      p = self._checkpoints_to_be_deleted.pop(0)
      # Do not delete the file if keep_checkpoint_every_n_hours is set and we
      # have reached N hours of training.
      should_keep = p[1] > self._next_checkpoint_time
      if should_keep:
        self._next_checkpoint_time += (
            self.saver_def.keep_checkpoint_every_n_hours * 3600)
        return

      # Otherwise delete the files.
      try:
        checkpoint_management.remove_checkpoint(
            self._CheckpointFilename(p), self.saver_def.version,
            meta_graph_suffix)
      except Exception as e:  # pylint: disable=broad-except
        logging.warning("Ignoring: %s", str(e))

  def as_saver_def(self):
    """Generates a `SaverDef` representation of this saver.

    Returns:
      A `SaverDef` proto.
    """
    return self.saver_def

  def to_proto(self, export_scope=None):
    """Converts this `Saver` to a `SaverDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `SaverDef` protocol buffer.
    """
    if export_scope is None:
      return self.saver_def

    if not (self.saver_def.filename_tensor_name.startswith(export_scope) and
            self.saver_def.save_tensor_name.startswith(export_scope) and
            self.saver_def.restore_op_name.startswith(export_scope)):
      return None

    saver_def = saver_pb2.SaverDef()
    saver_def.CopyFrom(self.saver_def)
    saver_def.filename_tensor_name = ops.strip_name_scope(
        saver_def.filename_tensor_name, export_scope)
    saver_def.save_tensor_name = ops.strip_name_scope(
        saver_def.save_tensor_name, export_scope)
    saver_def.restore_op_name = ops.strip_name_scope(saver_def.restore_op_name,
                                                     export_scope)
    return saver_def

  @staticmethod
  def from_proto(saver_def, import_scope=None):
    """Returns a `Saver` object created from `saver_def`.

    Args:
      saver_def: a `SaverDef` protocol buffer.
      import_scope: Optional `string`. Name scope to use.

    Returns:
      A `Saver` built from saver_def.
    """
    return Saver(saver_def=saver_def, name=import_scope)

  @property
  def last_checkpoints(self):
    """List of not-yet-deleted checkpoint filenames.

    You can pass any of the returned values to `restore()`.

    Returns:
      A list of checkpoint filenames, sorted from oldest to newest.
    """
    return list(self._CheckpointFilename(p) for p in self._last_checkpoints)

  def set_last_checkpoints(self, last_checkpoints):
    """DEPRECATED: Use set_last_checkpoints_with_time.

    Sets the list of old checkpoint filenames.

    Args:
      last_checkpoints: A list of checkpoint filenames.

    Raises:
      AssertionError: If last_checkpoints is not a list.
    """
    assert isinstance(last_checkpoints, list)
    # We use a timestamp of +inf so that this checkpoint will never be
    # deleted.  This is both safe and backwards compatible to a previous
    # version of the code which used s[1] as the "timestamp".
    self._last_checkpoints = [(s, np.inf) for s in last_checkpoints]

  def set_last_checkpoints_with_time(self, last_checkpoints_with_time):
    """Sets the list of old checkpoint filenames and timestamps.

    Args:
      last_checkpoints_with_time: A list of tuples of checkpoint filenames and
        timestamps.

    Raises:
      AssertionError: If last_checkpoints_with_time is not a list.
    """
    assert isinstance(last_checkpoints_with_time, list)
    self._last_checkpoints = last_checkpoints_with_time

  def recover_last_checkpoints(self, checkpoint_paths):
    """Recovers the internal saver state after a crash.

    This method is useful for recovering the "self._last_checkpoints" state.

    Globs for the checkpoints pointed to by `checkpoint_paths`.  If the files
    exist, use their mtime as the checkpoint timestamp.

    Args:
      checkpoint_paths: a list of checkpoint paths.
    """
    checkpoints_with_mtimes = []
    for checkpoint_path in checkpoint_paths:
      try:
        mtime = checkpoint_management.get_checkpoint_mtimes([checkpoint_path])
      except errors.NotFoundError:
        # It's fine if some other thread/process is deleting some older
        # checkpoint concurrently.
        continue
      if mtime:
        checkpoints_with_mtimes.append((checkpoint_path, mtime[0]))
    self.set_last_checkpoints_with_time(checkpoints_with_mtimes)

  def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True,
           strip_default_attrs=False,
           save_debug_info=False):
    # pylint: disable=line-too-long
    """Saves variables.

    This method runs the ops added by the constructor for saving variables.
    It requires a session in which the graph was launched.  The variables to
    save must also have been initialized.

    The method returns the path prefix of the newly created checkpoint files.
    This string can be passed directly to a call to `restore()`.
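
    For example (an illustrative sketch; assumes a training loop with a `step`
    counter):

    ```python
    save_path = saver.save(sess, '/tmp/model.ckpt', global_step=step)
    # The returned prefix (e.g. '/tmp/model.ckpt-1000') can later be passed
    # straight to restore().
    saver.restore(sess, save_path)
    ```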

    Args:
      sess: A Session to use to save the variables.
      save_path: String.  Prefix of filenames created for the checkpoint.
      global_step: If provided the global step number is appended to `save_path`
        to create the checkpoint filenames. The optional argument can be a
        `Tensor`, a `Tensor` name or an integer.
      latest_filename: Optional name for the protocol buffer file that will
        contain the list of most recent checkpoints.  That file, kept in the
        same directory as the checkpoint files, is automatically managed by the
        saver to keep track of recent checkpoints.  Defaults to 'checkpoint'.
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
      write_meta_graph: `Boolean` indicating whether or not to write the meta
        graph file.
      write_state: `Boolean` indicating whether or not to write the
        `CheckpointStateProto`.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued
          Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
        which is in the same directory as save_path and has `_debug` added
        before the file extension. This is only enabled when `write_meta_graph`
        is `True`.

    Returns:
      A string: path prefix used for the checkpoint files.  If the saver is
        sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
        is the number of shards created.
      If the saver is empty, returns None.

    Raises:
      TypeError: If `sess` is not a `Session`.
      ValueError: If `latest_filename` contains path components, or if it
        collides with `save_path`.
      RuntimeError: If save and restore ops weren't built.
    """
    # pylint: enable=line-too-long
    if not self._is_built and not context.executing_eagerly():
      raise RuntimeError(
          "`build()` should be called before save if defer_build==True")
    if latest_filename is None:
      latest_filename = "checkpoint"
    if self._write_version != saver_pb2.SaverDef.V2:
      logging.warning("*******************************************************")
      logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
      logging.warning("Consider switching to the more efficient V2 format:")
      logging.warning("   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
      logging.warning("now on by default.")
      logging.warning("*******************************************************")

    if os.path.split(latest_filename)[0]:
      raise ValueError("'latest_filename' must not contain path components")

    save_path = compat.as_str(save_path)
    if global_step is not None:
      if not isinstance(global_step, compat.integral_types):
        global_step = training_util.global_step(sess, global_step)
      checkpoint_file = "%s-%d" % (save_path, global_step)
      if self._pad_step_number:
        # Zero-pads the step numbers, so that they are sorted when listed.
        checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
    else:
      checkpoint_file = save_path
      if os.path.basename(save_path) == latest_filename and not self._sharded:
        # Guard against collision between data file and checkpoint state file.
        raise ValueError(
            "'latest_filename' collides with 'save_path': '%s' and '%s'" %
            (latest_filename, save_path))

    if (not context.executing_eagerly() and
        not isinstance(sess, session.SessionInterface)):
      raise TypeError("'sess' must be a Session; %s" % sess)

    save_path_parent = os.path.dirname(save_path)
    if not self._is_empty:
      try:
        if context.executing_eagerly():
          self._build_eager(
              checkpoint_file, build_save=True, build_restore=False)
          model_checkpoint_path = self.saver_def.save_tensor_name
        else:
          model_checkpoint_path = sess.run(
              self.saver_def.save_tensor_name,
              {self.saver_def.filename_tensor_name: checkpoint_file})

        model_checkpoint_path = compat.as_str(model_checkpoint_path)
        if write_state:
          self._RecordLastCheckpoint(model_checkpoint_path)
          checkpoint_management.update_checkpoint_state_internal(
              save_dir=save_path_parent,
              model_checkpoint_path=model_checkpoint_path,
              all_model_checkpoint_paths=self.last_checkpoints,
              latest_filename=latest_filename,
              save_relative_paths=self._save_relative_paths)
          self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
      except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
        if not gfile.IsDirectory(save_path_parent):
          exc = ValueError(
              "Parent directory of {} doesn't exist, can't save.".format(
                  save_path))
        raise exc

    if write_meta_graph:
      meta_graph_filename = checkpoint_management.meta_graph_filename(
          checkpoint_file, meta_graph_suffix=meta_graph_suffix)
      if not context.executing_eagerly():
        with sess.graph.as_default():
          self.export_meta_graph(
              meta_graph_filename,
              strip_default_attrs=strip_default_attrs,
              save_debug_info=save_debug_info)

    if self._is_empty:
      return None
    else:
      return model_checkpoint_path

  def export_meta_graph(self,
                        filename=None,
                        collection_list=None,
                        as_text=False,
                        export_scope=None,
                        clear_devices=False,
                        clear_extraneous_savers=False,
                        strip_default_attrs=False,
                        save_debug_info=False):
    # pylint: disable=line-too-long
    """Writes `MetaGraphDef` to save_path/filename.

    Args:
      filename: Optional meta_graph filename including the path.
      collection_list: List of string keys to collect.
      as_text: If `True`, writes the meta_graph as an ASCII proto.
      export_scope: Optional `string`. Name scope to remove.
      clear_devices: Whether or not to clear the device field for an `Operation`
        or `Tensor` during export.
      clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that is not associated with
        this Saver.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued
          Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
        which is in the same directory as filename and has `_debug` added
        before the file extension.

    Returns:
      A `MetaGraphDef` proto.
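
    For example (an illustrative sketch; the filename is hypothetical):

    ```python
    saver.export_meta_graph('/tmp/my-model.meta', clear_devices=True)
    ```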
    """
    # pylint: enable=line-too-long
    return export_meta_graph(
        filename=filename,
        graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
        saver_def=self.saver_def,
        collection_list=collection_list,
        as_text=as_text,
        export_scope=export_scope,
        clear_devices=clear_devices,
        clear_extraneous_savers=clear_extraneous_savers,
        strip_default_attrs=strip_default_attrs,
        save_debug_info=save_debug_info)

  def restore(self, sess, save_path):
    """Restores previously saved variables.

    This method runs the ops added by the constructor for restoring variables.
    It requires a session in which the graph was launched.  The variables to
    restore do not have to have been initialized, as restoring is itself a way
    to initialize variables.

    The `save_path` argument is typically a value previously returned from a
    `save()` call, or a call to `latest_checkpoint()`.
1280
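    A minimal usage sketch (the checkpoint directory below is hypothetical and
    a graph with matching variables is assumed to have been built):

    ```Python
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
      # Restore variable values from the most recent checkpoint, if any.
      saver.restore(sess, tf.train.latest_checkpoint('/tmp/train_dir'))
    ```
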
1281    Args:
1282      sess: A `Session` to use to restore the parameters. None in eager mode.
1283      save_path: Path where parameters were previously saved.
1284
1285    Raises:
1286      ValueError: If save_path is None or not a valid checkpoint.
1287    """
1288    if self._is_empty:
1289      return
1290    if save_path is None:
1291      raise ValueError("Can't load save_path when it is None.")
1292
1293    checkpoint_prefix = compat.as_text(save_path)
1294    if not checkpoint_management.checkpoint_exists_internal(checkpoint_prefix):
1295      raise ValueError("The passed save_path is not a valid checkpoint: " +
1296                       checkpoint_prefix)
1297
1298    logging.info("Restoring parameters from %s", checkpoint_prefix)
1299    try:
1300      if context.executing_eagerly():
1301        self._build_eager(save_path, build_save=False, build_restore=True)
1302      else:
1303        sess.run(self.saver_def.restore_op_name,
1304                 {self.saver_def.filename_tensor_name: save_path})
1305    except errors.NotFoundError as err:
1306      # There are three common conditions that might cause this error:
1307      # 0. The file is missing; we ignore it here, as this is checked above.
1308      # 1. An object-based checkpoint is being loaded with a name-based saver.
1309      # 2. The graph has been altered and a variable or other name is missing.
1310
1311      # 1. The checkpoint would not be loaded successfully as is. Try to parse
1312      # it as an object-based checkpoint.
1313      try:
1314        names_to_keys = object_graph_key_mapping(save_path)
1315      except errors.NotFoundError:
1316        # 2. This is not an object-based checkpoint, which likely means there
1317        # is a graph mismatch. Re-raise the original error with
1318        # a helpful message (b/110263146)
1319        raise _wrap_restore_error_with_msg(
1320            err, "a Variable name or other graph key that is missing")
1321
1322      # This is an object-based checkpoint. We'll print a warning and then do
1323      # the restore.
1324      logging.warning(
1325          "Restoring an object-based checkpoint using a name-based saver. This "
1326          "may be somewhat fragile, and will re-build the Saver. Instead, "
1327          "consider loading object-based checkpoints using "
1328          "tf.train.Checkpoint().")
1329      self._object_restore_saver = saver_from_object_based_checkpoint(
1330          checkpoint_path=save_path,
1331          var_list=self._var_list,
1332          builder=self._builder,
1333          names_to_keys=names_to_keys,
1334          cached_saver=self._object_restore_saver)
1335      self._object_restore_saver.restore(sess=sess, save_path=save_path)
1336    except errors.InvalidArgumentError as err:
1337      # There is a mismatch between the graph and the checkpoint being loaded.
1338      # We add a more reasonable error message here to help users (b/110263146)
1339      raise _wrap_restore_error_with_msg(
1340          err, "a mismatch between the current graph and the graph")
1341
1342  @staticmethod
1343  def _add_collection_def(meta_graph_def, key, export_scope=None):
1344    """Adds a collection to MetaGraphDef protocol buffer.
1345
1346    Args:
1347      meta_graph_def: MetaGraphDef protocol buffer.
1348      key: One of the GraphKeys or user-defined string.
1349      export_scope: Optional `string`. Name scope to remove.
1350    """
1351    meta_graph.add_collection_def(
1352        meta_graph_def, key, export_scope=export_scope)
1353
1354
1355@tf_export(v1=["train.import_meta_graph"])
1356def import_meta_graph(meta_graph_or_file,
1357                      clear_devices=False,
1358                      import_scope=None,
1359                      **kwargs):
1360  """Recreates a Graph saved in a `MetaGraphDef` proto.
1361
1362  This function takes a `MetaGraphDef` protocol buffer as input. If
1363  the argument is a file containing a `MetaGraphDef` protocol buffer,
1364  it constructs a protocol buffer from the file content. The function
1365  then adds all the nodes from the `graph_def` field to the
1366  current graph, recreates all the collections, and returns a saver
1367  constructed from the `saver_def` field.
1368
1369  In combination with `export_meta_graph()`, this function can be used to
1370
1371  * Serialize a graph along with other Python objects such as `QueueRunner`
1372    and `Variable` into a `MetaGraphDef`.
1373
1374  * Restart training from a saved graph and checkpoints.
1375
1376  * Run inference from a saved graph and checkpoints.
1377
1378  ```Python
1379  ...
1380  # Create a saver.
1381  saver = tf.compat.v1.train.Saver(...variables...)
1382  # Remember the training_op we want to run by adding it to a collection.
1383  tf.compat.v1.add_to_collection('train_op', train_op)
1384  sess = tf.compat.v1.Session()
1385  for step in range(1000000):
1386      sess.run(train_op)
1387      if step % 1000 == 0:
1388          # Saves checkpoint, which by default also exports a meta_graph
1389          # named 'my-model-global_step.meta'.
1390          saver.save(sess, 'my-model', global_step=step)
1391  ```
1392
1393  Later we can continue training from this saved `meta_graph` without building
1394  the model from scratch.
1395
1396  ```Python
1397  with tf.compat.v1.Session() as sess:
1398    new_saver = tf.compat.v1.train.import_meta_graph(
1399        'my-save-dir/my-model-10000.meta')
1400    new_saver.restore(sess, 'my-save-dir/my-model-10000')
1401    # tf.compat.v1.get_collection() returns a list. In this example we only
1402    # want the first one.
1403    train_op = tf.compat.v1.get_collection('train_op')[0]
1404    for step in range(1000000):
1405      sess.run(train_op)
1406  ```
1407
1408  NOTE: Restarting training from saved `meta_graph` only works if the
1409  device assignments have not changed.
1410
1411  Example:
1412  Variables, placeholders, and independent operations can also be stored, as
1413  shown in the following example.
1414
1415  ```Python
1416  # Saving contents and operations.
1417  v1 = tf.compat.v1.placeholder(tf.float32, name="v1")
1418  v2 = tf.compat.v1.placeholder(tf.float32, name="v2")
1419  v3 = tf.math.multiply(v1, v2)
1420  vx = tf.Variable(10.0, name="vx")
1421  v4 = tf.add(v3, vx, name="v4")
1422  saver = tf.compat.v1.train.Saver([vx])
1423  sess = tf.compat.v1.Session()
1424  sess.run(tf.compat.v1.global_variables_initializer())
1425  sess.run(vx.assign(tf.add(vx, vx)))
1426  result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
1427  print(result)
1428  saver.save(sess, "./model_ex1")
1429  ```
1430
1431  Later this model can be restored and its contents loaded.
1432
1433  ```Python
1434  # Restoring variables and running operations.
1435  saver = tf.compat.v1.train.import_meta_graph("./model_ex1.meta")
1436  sess = tf.compat.v1.Session()
1437  saver.restore(sess, "./model_ex1")
1438  result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
1439  print(result)
1440  ```
1441
1442  Args:
1443    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
1444      the path) containing a `MetaGraphDef`.
1445    clear_devices: Whether or not to clear the device field for an `Operation`
1446      or `Tensor` during import.
1447    import_scope: Optional `string`. Name scope to add. Only used when
1448      initializing from protocol buffer.
1449    **kwargs: Optional keyword arguments.
1450
1451  Returns:
1452    A saver constructed from `saver_def` in `MetaGraphDef` or None.
1453
1454    A None value is returned if no variables exist in the `MetaGraphDef`
1455    (i.e., there are no variables to restore).
1456
1457  Raises:
1458    RuntimeError: If called with eager execution enabled.
1459
1460  @compatibility(eager)
1461  Exporting/importing meta graphs is not supported. No graph exists when eager
1462  execution is enabled.
1463  @end_compatibility
1464  """  # pylint: disable=g-doc-exception
1465  return _import_meta_graph_with_return_elements(meta_graph_or_file,
1466                                                 clear_devices, import_scope,
1467                                                 **kwargs)[0]
1468
1469
1470def _import_meta_graph_with_return_elements(meta_graph_or_file,
1471                                            clear_devices=False,
1472                                            import_scope=None,
1473                                            return_elements=None,
1474                                            **kwargs):
1475  """Import MetaGraph, and return both a saver and returned elements."""
1476  if context.executing_eagerly():
1477    raise RuntimeError("Exporting/importing meta graphs is not supported when "
1478                       "eager execution is enabled. No graph exists when eager "
1479                       "execution is enabled.")
1480  if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
1481    meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
1482  else:
1483    meta_graph_def = meta_graph_or_file
1484
1485  imported_vars, imported_return_elements = (
1486      meta_graph.import_scoped_meta_graph_with_return_elements(
1487          meta_graph_def,
1488          clear_devices=clear_devices,
1489          import_scope=import_scope,
1490          return_elements=return_elements,
1491          **kwargs))
1492
1493  saver = _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
1494                                                 imported_vars)
1495  return saver, imported_return_elements
1496
1497
1498def _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
1499                                           imported_vars):
1500  """Return a saver for restoring variable values to an imported MetaGraph."""
1501  if meta_graph_def.HasField("saver_def"):
1502    # Infer the scope that is prepended by `import_scoped_meta_graph`.
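    # Keys in `imported_vars` are variable names with the import scope already
    # stripped, so removing that suffix from one imported variable's full name
    # leaves the scope prefix (including its trailing '/') that was prepended.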
1503    scope = import_scope
1504    var_names = list(imported_vars.keys())
1505    if var_names:
1506      sample_key = var_names[0]
1507      sample_var = imported_vars[sample_key]
1508      scope = sample_var.name[:-len(sample_key)]
1509
1510    return Saver(saver_def=meta_graph_def.saver_def, name=scope)
1511  else:
1512    if variables._all_saveable_objects(scope=import_scope):  # pylint: disable=protected-access
1513      # Return the default saver instance for all graph variables.
1514      return Saver()
1515    else:
1516      # If no graph variables exist, then a Saver cannot be constructed.
1517      logging.info("Saver not created because there are no variables in the"
1518                   " graph to restore")
1519      return None
1520
1521
1522@tf_export(v1=["train.export_meta_graph"])
1523def export_meta_graph(filename=None,
1524                      meta_info_def=None,
1525                      graph_def=None,
1526                      saver_def=None,
1527                      collection_list=None,
1528                      as_text=False,
1529                      graph=None,
1530                      export_scope=None,
1531                      clear_devices=False,
1532                      clear_extraneous_savers=False,
1533                      strip_default_attrs=False,
1534                      save_debug_info=False,
1535                      **kwargs):
1536  # pylint: disable=line-too-long
1537  """Returns `MetaGraphDef` proto.
1538
1539  Optionally writes it to filename.
1540
1541  This function exports the graph, saver, and collection objects into a
1542  `MetaGraphDef` protocol buffer with the intention of it being imported
1543  at a later time or location to restart training, run inference, or be
1544  used as a subgraph.
1545
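  For example, a minimal sketch (the filename below is hypothetical and a
  graph is assumed to have been built in the current default graph):

  ```Python
  meta_graph_def = tf.compat.v1.train.export_meta_graph(
      filename='/tmp/my-model.meta', clear_devices=True)
  ```
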
1546  Args:
1547    filename: Optional filename including the path for writing the generated
1548      `MetaGraphDef` protocol buffer.
1549    meta_info_def: `MetaInfoDef` protocol buffer.
1550    graph_def: `GraphDef` protocol buffer.
1551    saver_def: `SaverDef` protocol buffer.
1552    collection_list: List of string keys to collect.
1553    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
1554    graph: The `Graph` to export. If `None`, use the default graph.
1555    export_scope: Optional `string`. Name scope under which to extract the
1556      subgraph. The scope name will be stripped from the node definitions for
1557      easy import later into new name scopes. If `None`, the whole graph is
1558      exported. `graph_def` and `export_scope` cannot both be specified.
1559    clear_devices: Whether or not to clear the device field for an `Operation`
1560      or `Tensor` during export.
1561    clear_extraneous_savers: Remove any Saver-related information from the graph
1562      (both Save/Restore ops and SaverDefs) that is not associated with the
1563      provided SaverDef.
1564    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
1565      removed from the NodeDefs. For a detailed guide, see
1566      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
1567    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
1568      which is in the same directory as the filename and with `_debug`
1569      added before the file extension.
1570    **kwargs: Optional keyword arguments.
1571
1572  Returns:
1573    A `MetaGraphDef` proto.
1574
1575  Raises:
1576    ValueError: When the `GraphDef` is larger than 2GB.
1577    RuntimeError: If called with eager execution enabled.
1578
1579  @compatibility(eager)
1580  Exporting/importing meta graphs is not supported unless both `graph_def` and
1581  `graph` are provided. No graph exists when eager execution is enabled.
1582  @end_compatibility
1583  """
1584  # pylint: enable=line-too-long
1585  if context.executing_eagerly() and not (graph_def is not None and
1586                                          graph is not None):
1587    raise RuntimeError("Exporting/importing meta graphs is not supported when "
1588                       "eager execution is enabled. No graph exists when eager "
1589                       "execution is enabled.")
1590  meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
1591      filename=filename,
1592      meta_info_def=meta_info_def,
1593      graph_def=graph_def,
1594      saver_def=saver_def,
1595      collection_list=collection_list,
1596      as_text=as_text,
1597      graph=graph,
1598      export_scope=export_scope,
1599      clear_devices=clear_devices,
1600      clear_extraneous_savers=clear_extraneous_savers,
1601      strip_default_attrs=strip_default_attrs,
1602      save_debug_info=save_debug_info,
1603      **kwargs)
1604  return meta_graph_def
1605
1606
1607def _wrap_restore_error_with_msg(err, extra_verbiage):
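  """Return an error like `err` with a more helpful restore message."""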
1608  err_msg = ("Restoring from checkpoint failed. This is most likely "
1609             "due to {} from the checkpoint. Please ensure that you "
1610             "have not altered the graph expected based on the checkpoint. "
1611             "Original error:\n\n{}").format(extra_verbiage, err.message)
1612  return err.__class__(err.node_def, err.op, err_msg)
1613
1614
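# Register to_proto/from_proto so that Savers stored in the SAVERS collection
# can be serialized into and recovered from a MetaGraphDef.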
1615ops.register_proto_function(
1616    ops.GraphKeys.SAVERS,
1617    proto_type=saver_pb2.SaverDef,
1618    to_proto=Saver.to_proto,
1619    from_proto=Saver.from_proto)
1620
1621
1622def object_graph_key_mapping(checkpoint_path):
1623  """Return name to key mappings from the checkpoint.
1624
1625  Args:
1626    checkpoint_path: string, path to object-based checkpoint
1627
1628  Returns:
1629    Dictionary mapping tensor names to checkpoint keys.
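
  A minimal sketch (the checkpoint prefix below is hypothetical):

  ```Python
  # Maps full tensor names, e.g. 'dense/kernel', to the keys under which the
  # object-based checkpoint stored their values.
  names_to_keys = object_graph_key_mapping('/tmp/ckpt/model.ckpt-1000')
  ```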
1630  """
1631  reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
1632  object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
1633  object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
1634  object_graph_proto.ParseFromString(object_graph_string)
1635  names_to_keys = {}
1636  for node in object_graph_proto.nodes:
1637    for attribute in node.attributes:
1638      names_to_keys[attribute.full_name] = attribute.checkpoint_key
1639  return names_to_keys
1640
1641
1642def saver_from_object_based_checkpoint(checkpoint_path,
1643                                       var_list=None,
1644                                       builder=None,
1645                                       names_to_keys=None,
1646                                       cached_saver=None):
1647  """Return a `Saver` which reads from an object-based checkpoint.
1648
1649  This function validates that all variables in the variables list are remapped
1650  in the object-based checkpoint (or `names_to_keys` dict if provided). A
1651  saver will be created with the list of remapped variables.
1652
1653  The `cached_saver` argument allows the user to pass in a previously created
1654  saver, so multiple `saver.restore()` calls don't pollute the graph when graph
1655  building. This assumes that keys are consistent, meaning that both
1656    1) the `checkpoint_path` checkpoint, and
1657    2) the checkpoint used to create the `cached_saver`
1658  are the same type of object-based checkpoint. If this argument is set, this
1659  function will simply validate that all variables have been remapped by the
1660  checkpoint at `checkpoint_path`.
1661
1662  Note that in general, `tf.train.Checkpoint` should be used to restore/save an
1663  object-based checkpoint.
1664
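  A minimal sketch (the checkpoint path is hypothetical; `sess` and the graph
  variables are assumed to already exist):

  ```Python
  saver = saver_from_object_based_checkpoint('/tmp/ckpt/model-1')
  saver.restore(sess, '/tmp/ckpt/model-1')
  ```
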
1665  Args:
1666    checkpoint_path: string, path to object-based checkpoint
1667    var_list: list of `Variables` that appear in the checkpoint. If `None`,
1668      `var_list` will be set to all saveable objects.
1669    builder: a `BaseSaverBuilder` instance. If `None`, a new `BulkSaverBuilder`
1670      will be created.
1671    names_to_keys: dict mapping string tensor names to checkpoint keys. If
1672      `None`, this dict will be generated from the checkpoint file.
1673    cached_saver: Cached `Saver` object with remapped variables.
1674
1675  Returns:
1676    `Saver` with remapped variables for reading from an object-based checkpoint.
1677
1678  Raises:
1679    ValueError: If the checkpoint provided is not an object-based checkpoint.
1680    NotFoundError: If one of the variables in `var_list` can not be found in the
1681      checkpoint. This could mean the checkpoint or `names_to_keys` mapping is
1682      missing the variable.
1683  """
1684  if names_to_keys is None:
1685    try:
1686      names_to_keys = object_graph_key_mapping(checkpoint_path)
1687    except errors.NotFoundError:
1688      raise ValueError("Checkpoint in %s not an object-based checkpoint." %
1689                       checkpoint_path)
1690  if var_list is None:
1691    var_list = variables._all_saveable_objects()  # pylint: disable=protected-access
1692  if builder is None:
1693    builder = BulkSaverBuilder()
1694
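  # Collect the tensor names this Saver expects to restore and compare them
  # with the names recorded in the object-based checkpoint; anything missing
  # means the graph and the checkpoint no longer match.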
1695  saveables = saveable_object_util.validate_and_slice_inputs(var_list)
1696  current_names = set()
1697  for saveable in saveables:
1698    for spec in saveable.specs:
1699      current_names.add(spec.name)
1700  previous_names = set(names_to_keys.keys())
1701  missing_names = current_names - previous_names
1702  if missing_names:
1703    extra_names = previous_names - current_names
1704    intersecting_names = previous_names.intersection(current_names)
1705    raise errors.NotFoundError(
1706        None,
1707        None,
1708        message=(
1709            "\n\nExisting variables not in the checkpoint: %s\n\n"
1710            "Variable names when this checkpoint was written which don't "
1711            "exist now: %s\n\n"
1712            "(%d variable name(s) did match)\n\n"
1713            "Could not find some variables in the checkpoint (see names "
1714            "above). Saver was attempting to load an object-based checkpoint "
1715            "(saved using tf.train.Checkpoint or tf.keras.Model.save_weights) "
1716            "using variable names. If the checkpoint was written with eager "
1717            "execution enabled, it's possible that variable names have "
1718            "changed (for example missing a '_1' suffix). It's also "
1719            "possible that there are new variables which did not exist "
1720            "when the checkpoint was written. You can construct a "
1721            "Saver(var_list=...) with only the variables which previously "
1722            "existed, and if variable names have changed you may need to "
1723            "make this a dictionary with the old names as keys. If you're "
1724            "using an Estimator, you'll need to return a tf.train.Saver "
1725            "inside a tf.train.Scaffold from your model_fn.") %
1726        (", ".join(sorted(missing_names)), ", ".join(
1727            sorted(extra_names)), len(intersecting_names)))
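  # Rewrite each spec's name to its object-graph checkpoint key so the
  # name-based restore ops read the tensors written by tf.train.Checkpoint.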
1728  for saveable in saveables:
1729    for spec in saveable.specs:
1730      spec.name = names_to_keys[spec.name]
1731  if cached_saver is None:
1732    return Saver(saveables)
1733  return cached_saver
1734