1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16# pylint: disable=invalid-name
17"""Save and restore variables.
18
19Symbols in this file are deprecated. See replacements in
20tensorflow/python/training/trackable and tensorflow/python/training/saving.
21"""
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import collections
27import os.path
28import time
29import uuid
30
31import numpy as np
32from tensorflow.core.protobuf import meta_graph_pb2
33from tensorflow.core.protobuf import saver_pb2
34from tensorflow.core.protobuf import trackable_object_graph_pb2
35from tensorflow.python import pywrap_tensorflow
36from tensorflow.python.client import session
37from tensorflow.python.eager import context
38from tensorflow.python.framework import constant_op
39from tensorflow.python.framework import device as pydev
40from tensorflow.python.framework import errors
41from tensorflow.python.framework import meta_graph
42from tensorflow.python.framework import ops
43from tensorflow.python.ops import array_ops
44from tensorflow.python.ops import control_flow_ops
45from tensorflow.python.ops import gen_io_ops
46from tensorflow.python.ops import io_ops
47from tensorflow.python.ops import string_ops
48from tensorflow.python.ops import variables
49from tensorflow.python.platform import gfile
50from tensorflow.python.platform import tf_logging as logging
51from tensorflow.python.training import checkpoint_management
52from tensorflow.python.training import training_util
53from tensorflow.python.training.saving import saveable_object
54from tensorflow.python.training.saving import saveable_object_util
55from tensorflow.python.training.tracking import base as trackable
56from tensorflow.python.util import compat
57from tensorflow.python.util.tf_export import tf_export
58
59
60# TODO(allenl): Remove these aliases once all users are migrated off.
61get_checkpoint_state = checkpoint_management.get_checkpoint_state
62update_checkpoint_state = checkpoint_management.update_checkpoint_state
63generate_checkpoint_state_proto = (
64    checkpoint_management.generate_checkpoint_state_proto)
65latest_checkpoint = checkpoint_management.latest_checkpoint
66checkpoint_exists = checkpoint_management.checkpoint_exists
67get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes
68remove_checkpoint = checkpoint_management.remove_checkpoint
69
70
71class BaseSaverBuilder(object):
72  """Base class for Savers.
73
74  Can be extended to create different Ops.
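
  A minimal sketch of such an extension (a hypothetical subclass; it simply
  logs each tensor spec before delegating to the base implementation):

  ```python
  class LoggingSaverBuilder(BaseSaverBuilder):
    # Hypothetical builder: log each spec name, then reuse the default save op.
    def save_op(self, filename_tensor, saveables):
      for saveable in saveables:
        for spec in saveable.specs:
          logging.info("Saving %s", spec.name)
      return super(LoggingSaverBuilder, self).save_op(filename_tensor, saveables)
  ```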
75  """
76
77  SaveSpec = saveable_object.SaveSpec
78  SaveableObject = saveable_object.SaveableObject
79
80  # Aliases for code which was moved but still has lots of users.
81  VariableSaveable = saveable_object_util.ReferenceVariableSaveable
82  ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable
83
84  def __init__(self, write_version=saver_pb2.SaverDef.V2):
85    self._write_version = write_version
86
87  def save_op(self, filename_tensor, saveables):
88    """Create an Op to save 'saveables'.
89
90    This is intended to be overridden by subclasses that want to generate
91    different Ops.
92
93    Args:
94      filename_tensor: String Tensor.
95      saveables: A list of BaseSaverBuilder.SaveableObject objects.
96
97    Returns:
      An Operation that saves the variables.
99
100    Raises:
101      RuntimeError: (implementation detail) if "self._write_version" is an
102        unexpected value.
103    """
104    # pylint: disable=protected-access
105    tensor_names = []
106    tensors = []
107    tensor_slices = []
108    for saveable in saveables:
109      for spec in saveable.specs:
110        tensor_names.append(spec.name)
111        tensors.append(spec.tensor)
112        tensor_slices.append(spec.slice_spec)
113    if self._write_version == saver_pb2.SaverDef.V1:
114      return io_ops._save(
115          filename=filename_tensor,
116          tensor_names=tensor_names,
117          tensors=tensors,
118          tensor_slices=tensor_slices)
119    elif self._write_version == saver_pb2.SaverDef.V2:
120      # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
121      # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
122      return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
123                            tensors)
124    else:
      raise RuntimeError("Unexpected write_version: %s" % self._write_version)
126
127  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
128                   restore_sequentially):
129    """Restore all tensors contained in saveables.
130
131    By default, this issues separate calls to `restore_op` for each saveable.
132    Subclasses may override to load multiple saveables in a single call.
133
134    Args:
135      filename_tensor: String Tensor.
136      saveables: List of BaseSaverBuilder.SaveableObject objects.
137      preferred_shard: Int.  Shard to open first when loading a sharded file.
138      restore_sequentially: Unused.  Bool.  If true, each restore is sequential.
139
140    Returns:
141      A list of Tensors resulting from reading 'saveable' from
142        'filename'.
143
144    """
145    del restore_sequentially
146    all_tensors = []
147    for saveable in saveables:
148      if saveable.device:
149        device = saveable_object_util.set_cpu0(saveable.device)
150      else:
151        device = None
152      with ops.device(device):
153        all_tensors.extend(
154            self.restore_op(filename_tensor, saveable, preferred_shard))
155    return all_tensors
156
157  # pylint: disable=unused-argument
158  def restore_op(self, filename_tensor, saveable, preferred_shard):
159    """Create ops to restore 'saveable'.
160
161    This is intended to be overridden by subclasses that want to generate
162    different Ops.
163
164    Args:
165      filename_tensor: String Tensor.
166      saveable: A BaseSaverBuilder.SaveableObject object.
167      preferred_shard: Int.  Shard to open first when loading a sharded file.
168
169    Returns:
170      A list of Tensors resulting from reading 'saveable' from
171        'filename'.
172    """
173    # pylint: disable=protected-access
174    tensors = []
175    for spec in saveable.specs:
176      tensors.append(
177          io_ops.restore_v2(
178              filename_tensor,
179              [spec.name],
180              [spec.slice_spec],
181              [spec.dtype])[0])
182
183    return tensors
184  # pylint: enable=unused-argument
185
186  def sharded_filename(self, filename_tensor, shard, num_shards):
187    """Append sharding information to a filename.
188
189    Args:
190      filename_tensor: A string tensor.
191      shard: Integer.  The shard for the filename.
192      num_shards: An int Tensor for the number of shards.
193
194    Returns:
195      A string tensor.
196    """
197    return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)
198
199  def _AddSaveOps(self, filename_tensor, saveables):
200    """Add ops to save variables that are on the same shard.
201
202    Args:
203      filename_tensor: String Tensor.
204      saveables: A list of SaveableObject objects.
205
206    Returns:
207      A tensor with the filename used to save.
208    """
209    save = self.save_op(filename_tensor, saveables)
210    return control_flow_ops.with_dependencies([save], filename_tensor)
211
212  def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device):
213    """Add ops to save the params per shard, for the V2 format.
214
215    Note that the sharded save procedure for the V2 format is different from
216    V1: there is a special "merge" step that merges the small metadata produced
217    from each device.
218
219    Args:
220      checkpoint_prefix: scalar String Tensor.  Interpreted *NOT AS A
221        FILENAME*, but as a prefix of a V2 checkpoint;
222      per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
223        returned by _GroupByDevices().
224
225    Returns:
226      An op to save the variables, which, when evaluated, returns the prefix
227        "<user-fed prefix>" only and does not include the sharded spec suffix.
228    """
229    # IMPLEMENTATION DETAILS: most clients should skip.
230    #
231    # Suffix for any well-formed "checkpoint_prefix", when sharded.
232    # Transformations:
233    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
234    # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>.
235    #
236    # Example:
237    #   During runtime, a temporary directory is first created, which contains
238    #   files
239    #
240    #     <train dir>/myckpt_temp/
241    #        part-?????-of-?????{.index, .data-00000-of-00001}
242    #
243    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
244    #
245    #     <train dir>/
246    #        myckpt{.index, .data-?????-of-?????}
247    #
248    # Users only need to interact with the user-specified prefix, which is
249    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
250    # prefix directly, instead of any physical pathname.  (On failure and
251    # subsequent restore, an outdated and orphaned temporary directory can be
252    # safely removed.)
253    _SHARDED_SUFFIX = "_temp_%s/part" % uuid.uuid4().hex
254    tmp_checkpoint_prefix = string_ops.string_join(
255        [checkpoint_prefix, _SHARDED_SUFFIX])
256
257    num_shards = len(per_device)
258    sharded_saves = []
259    sharded_prefixes = []
260    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
261    last_device = None
262    for shard, (device, saveables) in enumerate(per_device):
263      last_device = device
264      with ops.device(saveable_object_util.set_cpu0(device)):
265        sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
266                                                 num_shards_tensor)
267        sharded_prefixes.append(sharded_filename)
268        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
269
270    with ops.control_dependencies([x.op for x in sharded_saves]):
271      # Co-locates the merge step with the last device.
272      with ops.device(saveable_object_util.set_cpu0(last_device)):
273        # V2 format write path consists of a metadata merge step.  Once merged,
274        # attempts to delete the temporary directory, "<user-fed prefix>_temp".
275        merge_step = gen_io_ops.merge_v2_checkpoints(
276            sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
277        with ops.control_dependencies([merge_step]):
278          # Returns the prefix "<user-fed prefix>" only.  DOES NOT include the
279          # sharded spec suffix.
280          return array_ops.identity(checkpoint_prefix)
281
282  def _AddShardedSaveOps(self, filename_tensor, per_device):
283    """Add ops to save the params per shard.
284
285    Args:
286      filename_tensor: a scalar String Tensor.
287      per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
288        returned by _GroupByDevices().
289
290    Returns:
291      An op to save the variables.
292    """
293    if self._write_version == saver_pb2.SaverDef.V2:
294      return self._AddShardedSaveOpsForV2(filename_tensor, per_device)
295
296    num_shards = len(per_device)
297    sharded_saves = []
298    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
299    for shard, (device, saveables) in enumerate(per_device):
300      with ops.device(device):
301        sharded_filename = self.sharded_filename(filename_tensor, shard,
302                                                 num_shards_tensor)
303        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
304    # Return the sharded name for the save path.
305    with ops.control_dependencies([x.op for x in sharded_saves]):
306      return gen_io_ops.sharded_filespec(filename_tensor, num_shards_tensor)
307
308  def _AddRestoreOps(self,
309                     filename_tensor,
310                     saveables,
311                     restore_sequentially,
312                     reshape,
313                     preferred_shard=-1,
314                     name="restore_all"):
315    """Add operations to restore saveables.
316
317    Args:
318      filename_tensor: Tensor for the path of the file to load.
319      saveables: A list of SaveableObject objects.
320      restore_sequentially: True if we want to restore variables sequentially
321        within a shard.
322      reshape: True if we want to reshape loaded tensors to the shape of
323        the corresponding variable.
324      preferred_shard: Shard to open first when loading a sharded file.
325      name: Name for the returned op.
326
327    Returns:
328      An Operation that restores the variables.
329    """
330    all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
331                                    restore_sequentially)
332
333    assign_ops = []
334    idx = 0
335    # Load and optionally reshape on the CPU, as string tensors are not
336    # available on the GPU.
337    # TODO(touts): Re-enable restore on GPU when we can support annotating
338    # string tensors as "HostMemory" inputs.
339    for saveable in saveables:
340      shapes = None
341      if reshape:
342        # Compute the shapes, let the restore op decide if and how to do
343        # the reshape.
344        shapes = []
345        for spec in saveable.specs:
346          v = spec.tensor
347          shape = v.get_shape()
348          if not shape.is_fully_defined():
349            shape = array_ops.shape(v)
350          shapes.append(shape)
351      saveable_tensors = all_tensors[idx:idx + len(saveable.specs)]
352      idx += len(saveable.specs)
353      assign_ops.append(saveable.restore(saveable_tensors, shapes))
354
355    # Create a Noop that has control dependencies from all the updates.
356    return control_flow_ops.group(*assign_ops, name=name)
357
358  def _AddShardedRestoreOps(self, filename_tensor, per_device,
359                            restore_sequentially, reshape):
360    """Add Ops to restore variables from multiple devices.
361
362    Args:
363      filename_tensor: Tensor for the path of the file to load.
364      per_device: A list of (device, SaveableObject) pairs, as
365        returned by _GroupByDevices().
366      restore_sequentially: True if we want to restore variables sequentially
367        within a shard.
368      reshape: True if we want to reshape loaded tensors to the shape of
369        the corresponding variable.
370
371    Returns:
372      An Operation that restores the variables.
373    """
374    sharded_restores = []
375    for shard, (device, saveables) in enumerate(per_device):
376      with ops.device(device):
377        sharded_restores.append(
378            self._AddRestoreOps(
379                filename_tensor,
380                saveables,
381                restore_sequentially,
382                reshape,
383                preferred_shard=shard,
384                name="restore_shard"))
385    return control_flow_ops.group(*sharded_restores, name="restore_all")
386
387  def _GroupByDevices(self, saveables):
388    """Group Variable tensor slices per device.
389
390    TODO(touts): Make sure that all the devices found are on different
391    job/replica/task/cpu|gpu.  It would be bad if 2 were on the same device.
392    It can happen if the devices are unspecified.
393
394    Args:
395      saveables: A list of BaseSaverBuilder.SaveableObject objects.
396
397    Returns:
      A list of (device_name, BaseSaverBuilder.SaveableObject) tuples.
399      The list is sorted by ascending device_name.
400
401    Raises:
402      ValueError: If the tensors of a saveable are on different devices.
403    """
404    per_device = collections.defaultdict(lambda: [])
405    for saveable in saveables:
406      canonical_device = set(
407          pydev.canonical_name(spec.tensor.device) for spec in saveable.specs)
408      if len(canonical_device) != 1:
409        raise ValueError("All tensors of a saveable object must be "
410                         "on the same device: %s" % saveable.name)
411      per_device[canonical_device.pop()].append(saveable)
412    return sorted(per_device.items(), key=lambda t: t[0])
413
414  def build(self,
415            names_to_saveables,
416            reshape=False,
417            sharded=False,
418            max_to_keep=5,
419            keep_checkpoint_every_n_hours=10000.0,
420            name=None,
421            restore_sequentially=False,
422            filename="model"):
423    """Builds save/restore graph nodes or runs save/restore in eager mode.
424
425    Args:
426      names_to_saveables: A dictionary mapping name to a Variable or
427        SaveableObject. Each name will be associated with the
428        corresponding variable in the checkpoint.
      reshape: If True, allow restoring parameters from a checkpoint
        where the parameters have a different shape.  This is
        only needed when you try to restore from a Dist-Belief checkpoint,
        and only sometimes.
433      sharded: If True, shard the checkpoints, one per device that has
434        Variable nodes.
435      max_to_keep: Maximum number of checkpoints to keep.  As new checkpoints
436        are created, old ones are deleted.  If None or 0, no checkpoints are
437        deleted from the filesystem but only the last one is kept in the
438        `checkpoint` file.  Presently the number is only roughly enforced.  For
439        example in case of restarts more than max_to_keep checkpoints may be
440        kept.
441      keep_checkpoint_every_n_hours: How often checkpoints should be kept.
442        Defaults to 10,000 hours.
443      name: String.  Optional name to use as a prefix when adding operations.
444      restore_sequentially: A Bool, which if true, causes restore of different
445        variables to happen sequentially within each device.
446      filename: If known at graph construction time, filename used for variable
447        loading/saving. If None, then the default name "model" will be used.
448
449    Returns:
450      A SaverDef proto.
451
452    Raises:
453      TypeError: If 'names_to_saveables' is not a dictionary mapping string
454        keys to variable Tensors.
455      ValueError: If any of the keys or values in 'names_to_saveables' is not
456        unique.
457    """
458    return self._build_internal(
459        names_to_saveables=names_to_saveables,
460        reshape=reshape,
461        sharded=sharded,
462        max_to_keep=max_to_keep,
463        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
464        name=name,
465        restore_sequentially=restore_sequentially,
466        filename=filename)
467
468  def _build_internal(self,
469                      names_to_saveables,
470                      reshape=False,
471                      sharded=False,
472                      max_to_keep=5,
473                      keep_checkpoint_every_n_hours=10000.0,
474                      name=None,
475                      restore_sequentially=False,
476                      filename="model",
477                      build_save=True,
478                      build_restore=True):
479    """build() with option to only perform save and restore."""
480    if not context.executing_eagerly() and (not build_save or
481                                            not build_restore):
      raise ValueError("save and restore operations need to be built together "
                       "when eager execution is not enabled.")
484
485    saveables = saveable_object_util.validate_and_slice_inputs(
486        names_to_saveables)
487    if max_to_keep is None:
488      max_to_keep = 0
489
490    with ops.name_scope(name, "save",
491                        [saveable.op for saveable in saveables]) as name:
492      # Add a placeholder string tensor for the filename.
493      filename_tensor = array_ops.placeholder_with_default(
494          filename or "model", shape=(), name="filename")
495      # Keep the name "Const" for backwards compatibility.
496      filename_tensor = array_ops.placeholder_with_default(
497          filename_tensor, shape=(), name="Const")
498
499      # Add the save ops.
500      if sharded:
501        per_device = self._GroupByDevices(saveables)
502        if build_save:
503          save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
504        if build_restore:
505          restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
506                                                  restore_sequentially, reshape)
507      else:
508        if build_save:
509          save_tensor = self._AddSaveOps(filename_tensor, saveables)
510        if build_restore:
511          restore_op = self._AddRestoreOps(filename_tensor, saveables,
512                                           restore_sequentially, reshape)
513
514    # In the following use case, it's possible to have restore_ops be called
515    # something else:
516    # - Build inference graph and export a meta_graph.
517    # - Import the inference meta_graph
518    # - Extend the inference graph to a train graph.
519    # - Export a new meta_graph.
520    # Now the second restore_op will be called "restore_all_1".
    # As such, comment out the assert for now until we know whether supporting
    # such a usage model makes sense.
523    #
524    # assert restore_op.name.endswith("restore_all"), restore_op.name
525    if context.executing_eagerly():
526      # Store the tensor values to the tensor_names.
527      save_tensor_name = save_tensor.numpy() if build_save else ""
528      return saver_pb2.SaverDef(
529          filename_tensor_name=filename_tensor.numpy(),
530          save_tensor_name=save_tensor_name,
531          restore_op_name="",
532          max_to_keep=max_to_keep,
533          sharded=sharded,
534          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
535          version=self._write_version)
536    else:
537      graph = ops.get_default_graph()
538      # Do some sanity checking on collections containing
539      # PartitionedVariables. If a saved collection has a PartitionedVariable,
540      # the GraphDef needs to include concat ops to get the value (or there'll
541      # be a lookup error on load).
542      check_collection_list = graph.get_all_collection_keys()
543      for collection_type in check_collection_list:
544        for element in graph.get_collection(collection_type):
545          if isinstance(element, variables.PartitionedVariable):
546            try:
547              graph.get_operation_by_name(element.name)
548            except KeyError:
549              # Create a concat op for this PartitionedVariable. The user may
550              # not need it, but we'll try looking it up on MetaGraph restore
551              # since it's in a collection.
552              element.as_tensor()
553      return saver_pb2.SaverDef(
554          filename_tensor_name=filename_tensor.name,
555          save_tensor_name=save_tensor.name,
556          restore_op_name=restore_op.name,
557          max_to_keep=max_to_keep,
558          sharded=sharded,
559          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
560          version=self._write_version)
561
562
563class BulkSaverBuilder(BaseSaverBuilder):
564  """SaverBuilder with support for bulk restoring multiple saveables."""
565
566  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
567                   restore_sequentially):
568
569    # Ignored: bulk restore is internally sequential.
570    del restore_sequentially
571    restore_specs = []
572    for saveable in saveables:
573      for spec in saveable.specs:
574        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
575
576    names, slices, dtypes = zip(*restore_specs)
577    # Load all tensors onto CPU 0 for compatibility with existing code.
578    with ops.device("cpu:0"):
579      return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
580
581
582def _get_saver_or_default():
583  """Returns the saver from SAVERS collection, or creates a default one.
584
  This method is used by other members of the training module, such as
  `Scaffold` or `CheckpointSaverHook`.
587
588  Returns:
589    `Saver`.
590
591  Raises:
    RuntimeError: If the SAVERS collection already has more than one item.
593  """
594  collection_key = ops.GraphKeys.SAVERS
595  savers = ops.get_collection(collection_key)
596  if savers:
597    if len(savers) > 1:
598      raise RuntimeError(
599          "More than one item in collection {}. "
600          "Please indicate which one to use by passing it to the constructor.".
601          format(collection_key))
602    return savers[0]
603  saver = Saver(sharded=True, allow_empty=True)
604  if saver is not None:
605    ops.add_to_collection(collection_key, saver)
606  return saver
607
608
609@tf_export(v1=["train.Saver"])
610class Saver(object):
611  """Saves and restores variables.
612
613  See [Variables](https://tensorflow.org/guide/variables)
614  for an overview of variables, saving and restoring.
615
616  The `Saver` class adds ops to save and restore variables to and from
617  *checkpoints*.  It also provides convenience methods to run these ops.
618
619  Checkpoints are binary files in a proprietary format which map variable names
620  to tensor values.  The best way to examine the contents of a checkpoint is to
621  load it using a `Saver`.
622
623  Savers can automatically number checkpoint filenames with a provided counter.
624  This lets you keep multiple checkpoints at different steps while training a
625  model.  For example you can number the checkpoint filenames with the training
626  step number.  To avoid filling up disks, savers manage checkpoint files
627  automatically. For example, they can keep only the N most recent files, or
628  one checkpoint for every N hours of training.
629
630  You number checkpoint filenames by passing a value to the optional
631  `global_step` argument to `save()`:
632
633  ```python
634  saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
635  ...
636  saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
637  ```
638
639  Additionally, optional arguments to the `Saver()` constructor let you control
640  the proliferation of checkpoint files on disk:
641
642  * `max_to_keep` indicates the maximum number of recent checkpoint files to
643    keep.  As new files are created, older files are deleted.   If None or 0,
644    no checkpoints are deleted from the filesystem but only the last one is
    kept in the `checkpoint` file.  Defaults to 5 (that is, the 5 most recent
    checkpoint files are kept).
647
648  * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent
649    `max_to_keep` checkpoint files, you might want to keep one checkpoint file
650    for every N hours of training.  This can be useful if you want to later
651    analyze how a model progressed during a long training session.  For
652    example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep
653    one checkpoint file for every 2 hours of training.  The default value of
654    10,000 hours effectively disables the feature.
655
656  Note that you still have to call the `save()` method to save the model.
657  Passing these arguments to the constructor will not save variables
658  automatically for you.
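
  As a rough sketch (the variable list and step counts are placeholders), a
  saver that keeps the 3 most recent checkpoints plus one checkpoint for every
  2 hours of training could be constructed as:

  ```python
  saver = tf.train.Saver(...variables...,
                         max_to_keep=3,
                         keep_checkpoint_every_n_hours=2)
  ```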
659
660  A training program that saves regularly looks like:
661
662  ```python
663  ...
664  # Create a saver.
665  saver = tf.train.Saver(...variables...)
666  # Launch the graph and train, saving the model every 1,000 steps.
667  sess = tf.Session()
668  for step in xrange(1000000):
669      sess.run(..training_op..)
670      if step % 1000 == 0:
671          # Append the step number to the checkpoint name:
672          saver.save(sess, 'my-model', global_step=step)
673  ```
674
675  In addition to checkpoint files, savers keep a protocol buffer on disk with
676  the list of recent checkpoints. This is used to manage numbered checkpoint
677  files and by `latest_checkpoint()`, which makes it easy to discover the path
678  to the most recent checkpoint. That protocol buffer is stored in a file named
679  'checkpoint' next to the checkpoint files.
680
681  If you create several savers, you can specify a different filename for the
682  protocol buffer file in the call to `save()`.
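
  For instance (hypothetical saver and directory names), a second saver can
  keep its own state file, and its most recent checkpoint can be looked up
  later by passing the same name to `latest_checkpoint()`:

  ```python
  eval_saver.save(sess, 'eval-model', latest_filename='checkpoint_eval')
  path = tf.train.latest_checkpoint(train_dir, latest_filename='checkpoint_eval')
  ```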
683  """
684
685  def __init__(self,
686               var_list=None,
687               reshape=False,
688               sharded=False,
689               max_to_keep=5,
690               keep_checkpoint_every_n_hours=10000.0,
691               name=None,
692               restore_sequentially=False,
693               saver_def=None,
694               builder=None,
695               defer_build=False,
696               allow_empty=False,
697               write_version=saver_pb2.SaverDef.V2,
698               pad_step_number=False,
699               save_relative_paths=False,
700               filename=None):
701    """Creates a `Saver`.
702
703    The constructor adds ops to save and restore variables.
704
705    `var_list` specifies the variables that will be saved and restored. It can
706    be passed as a `dict` or a list:
707
708    * A `dict` of names to variables: The keys are the names that will be
709      used to save or restore the variables in the checkpoint files.
710    * A list of variables: The variables will be keyed with their op name in
711      the checkpoint files.
712
713    For example:
714
715    ```python
716    v1 = tf.Variable(..., name='v1')
717    v2 = tf.Variable(..., name='v2')
718
719    # Pass the variables as a dict:
720    saver = tf.train.Saver({'v1': v1, 'v2': v2})
721
722    # Or pass them as a list.
723    saver = tf.train.Saver([v1, v2])
724    # Passing a list is equivalent to passing a dict with the variable op names
725    # as keys:
726    saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
727    ```
728
729    The optional `reshape` argument, if `True`, allows restoring a variable from
730    a save file where the variable had a different shape, but the same number
731    of elements and type.  This is useful if you have reshaped a variable and
732    want to reload it from an older checkpoint.
733
734    The optional `sharded` argument, if `True`, instructs the saver to shard
735    checkpoints per device.
736
737    Args:
738      var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
739        names to `SaveableObject`s. If `None`, defaults to the list of all
740        saveable objects.
741      reshape: If `True`, allows restoring parameters from a checkpoint
742        where the variables have a different shape.
743      sharded: If `True`, shard the checkpoints, one per device.
744      max_to_keep: Maximum number of recent checkpoints to keep.
745        Defaults to 5.
746      keep_checkpoint_every_n_hours: How often to keep checkpoints.
747        Defaults to 10,000 hours.
748      name: String.  Optional name to use as a prefix when adding operations.
749      restore_sequentially: A `Bool`, which if true, causes restore of different
750        variables to happen sequentially within each device.  This can lower
751        memory usage when restoring very large models.
752      saver_def: Optional `SaverDef` proto to use instead of running the
753        builder. This is only useful for specialty code that wants to recreate
754        a `Saver` object for a previously built `Graph` that had a `Saver`.
755        The `saver_def` proto should be the one returned by the
756        `as_saver_def()` call of the `Saver` that was created for that `Graph`.
757      builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
758        Defaults to `BulkSaverBuilder()`.
759      defer_build: If `True`, defer adding the save and restore ops to the
760        `build()` call. In that case `build()` should be called before
761        finalizing the graph or using the saver.
762      allow_empty: If `False` (default) raise an error if there are no
763        variables in the graph. Otherwise, construct the saver anyway and make
764        it a no-op.
765      write_version: controls what format to use when saving checkpoints.  It
766        also affects certain filepath matching logic.  The V2 format is the
767        recommended choice: it is much more optimized than V1 in terms of
768        memory required and latency incurred during restore.  Regardless of
769        this flag, the Saver is able to restore from both V2 and V1 checkpoints.
770      pad_step_number: if True, pads the global step number in the checkpoint
771        filepaths to some fixed width (8 by default).  This is turned off by
772        default.
773      save_relative_paths: If `True`, will write relative paths to the
774        checkpoint state file. This is needed if the user wants to copy the
775        checkpoint directory and reload from the copied directory.
776      filename: If known at graph construction time, filename used for variable
777        loading/saving.
778
779    Raises:
780      TypeError: If `var_list` is invalid.
781      ValueError: If any of the keys or values in `var_list` are not unique.
      RuntimeError: If eager execution is enabled and `var_list` does not
        specify a list of variables to save.
784
785    @compatibility(eager)
786    When eager execution is enabled, `var_list` must specify a `list` or `dict`
787    of variables to save. Otherwise, a `RuntimeError` will be raised.
788
789    Although Saver works in some cases when executing eagerly, it is
    fragile. Please switch to `tf.train.Checkpoint` or
    `tf.keras.Model.save_weights`, which perform more robust, object-based
    saving. These APIs will load checkpoints written by `Saver`.
793    @end_compatibility
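
    A rough sketch of the object-based replacement (the `model`, `optimizer`
    and path below are placeholders):

    ```python
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    save_path = checkpoint.save('/tmp/ckpt/model')
    # Later, in the same or another program:
    checkpoint.restore(save_path)
    ```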
794    """
795    if defer_build and var_list:
796      raise ValueError(
797          "If `var_list` is provided then build cannot be deferred. "
798          "Either set defer_build=False or var_list=None.")
799    if context.executing_eagerly():
800      logging.warning(
801          "Saver is deprecated, please switch to tf.train.Checkpoint or "
802          "tf.keras.Model.save_weights for training checkpoints. When "
803          "executing eagerly variables do not necessarily have unique names, "
804          "and so the variable.name-based lookups Saver performs are "
805          "error-prone.")
806      if var_list is None:
807        raise RuntimeError(
808            "When eager execution is enabled, `var_list` must specify a list "
809            "or dict of variables to save")
810    self._var_list = var_list
811    self._reshape = reshape
812    self._sharded = sharded
813    self._max_to_keep = max_to_keep
814    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
815    self._name = name
816    self._restore_sequentially = restore_sequentially
817    self.saver_def = saver_def
818    self._builder = builder
819    self._is_built = False
820    self._allow_empty = allow_empty
821    self._is_empty = None
822    self._write_version = write_version
823    self._pad_step_number = pad_step_number
824    self._filename = filename
825    self._last_checkpoints = []
826    self._checkpoints_to_be_deleted = []
827    if context.executing_eagerly():
828      self._next_checkpoint_time = (
829          time.time() + self._keep_checkpoint_every_n_hours * 3600)
830    elif not defer_build:
831      self.build()
832    if self.saver_def:
833      self._check_saver_def()
834      self._write_version = self.saver_def.version
835    self._save_relative_paths = save_relative_paths
836    # For compatibility with object-based checkpoints, we may build a second
837    # Saver to read the renamed keys.
838    self._object_restore_saver = None
839
840  def build(self):
841    if context.executing_eagerly():
842      raise RuntimeError("Use save/restore instead of build in eager mode.")
843    self._build(self._filename, build_save=True, build_restore=True)
844
845  def _build_eager(self, checkpoint_path, build_save, build_restore):
846    self._build(
847        checkpoint_path, build_save=build_save, build_restore=build_restore)
848
849  def _build(self, checkpoint_path, build_save, build_restore):
850    """Builds saver_def."""
851    if not context.executing_eagerly():
852      if self._is_built:
853        return
854      self._is_built = True
855
856    if not self.saver_def or context.executing_eagerly():
857      if self._builder is None:
858        self._builder = BulkSaverBuilder(self._write_version)
859
860      if self._var_list is None:
861        # pylint: disable=protected-access
862        self._var_list = variables._all_saveable_objects()
863      if not self._var_list:
864        if self._allow_empty:
865          self._is_empty = True
866          return
867        else:
868          raise ValueError("No variables to save")
869      self._is_empty = False
870
871      self.saver_def = self._builder._build_internal(  # pylint: disable=protected-access
872          self._var_list,
873          reshape=self._reshape,
874          sharded=self._sharded,
875          max_to_keep=self._max_to_keep,
876          keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
877          name=self._name,
878          restore_sequentially=self._restore_sequentially,
879          filename=checkpoint_path,
880          build_save=build_save, build_restore=build_restore)
881    elif self.saver_def and self._name:
882      # Since self._name is used as a name_scope by builder(), we are
883      # overloading the use of this field to represent the "import_scope" as
884      # well.
885      self.saver_def.filename_tensor_name = ops.prepend_name_scope(
886          self.saver_def.filename_tensor_name, self._name)
887      self.saver_def.save_tensor_name = ops.prepend_name_scope(
888          self.saver_def.save_tensor_name, self._name)
889      self.saver_def.restore_op_name = ops.prepend_name_scope(
890          self.saver_def.restore_op_name, self._name)
891
892    self._check_saver_def()
893    if not context.executing_eagerly():
894      # Updates next checkpoint time.
895      # Set in __init__ when executing eagerly.
896      self._next_checkpoint_time = (
897          time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)
898
899  def _check_saver_def(self):
900    if not isinstance(self.saver_def, saver_pb2.SaverDef):
901      raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
902                       self.saver_def)
903    if not context.executing_eagerly():
904      if not self.saver_def.save_tensor_name:
905        raise ValueError("saver_def must specify the save_tensor_name: %s" %
906                         str(self.saver_def))
907      if not self.saver_def.restore_op_name:
908        raise ValueError("saver_def must specify the restore_op_name: %s" %
909                         str(self.saver_def))
910
911  def _CheckpointFilename(self, p):
912    """Returns the checkpoint filename given a `(filename, time)` pair.
913
914    Args:
915      p: (filename, time) pair.
916
917    Returns:
918      Checkpoint file name.
919    """
920    name, _ = p
921    return name
922
923  def _RecordLastCheckpoint(self, latest_save_path):
924    """Manages the list of the latest checkpoints."""
925    if not self.saver_def.max_to_keep:
926      return
927    # Remove first from list if the same name was used before.
928    for p in self._last_checkpoints:
929      if latest_save_path == self._CheckpointFilename(p):
930        self._last_checkpoints.remove(p)
931    # Append new path to list
932    self._last_checkpoints.append((latest_save_path, time.time()))
933
934    # If more than max_to_keep, remove oldest.
935    if len(self._last_checkpoints) > self.saver_def.max_to_keep:
936      self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0))
937
938  def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"):
939    """Deletes old checkpoints if necessary.
940
    `self._checkpoints_to_be_deleted` contains checkpoints that exceed
    `max_to_keep`; they will be deleted.  If
943    `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
944    every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
945    kept for every 0.5 hours of training; if `N` is 10, an additional
946    checkpoint is kept for every 10 hours of training.
947
948    Args:
949      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
950    """
951    if self._checkpoints_to_be_deleted:
952      p = self._checkpoints_to_be_deleted.pop(0)
      # Do not delete the file if keep_checkpoint_every_n_hours is set and we
      # have reached N hours of training.
955      should_keep = p[1] > self._next_checkpoint_time
956      if should_keep:
957        self._next_checkpoint_time += (
958            self.saver_def.keep_checkpoint_every_n_hours * 3600)
959        return
960
961      # Otherwise delete the files.
962      try:
963        checkpoint_management.remove_checkpoint(
964            self._CheckpointFilename(p), self.saver_def.version,
965            meta_graph_suffix)
966      except Exception as e:  # pylint: disable=broad-except
967        logging.warning("Ignoring: %s", str(e))
968
969  def as_saver_def(self):
970    """Generates a `SaverDef` representation of this saver.
971
972    Returns:
973      A `SaverDef` proto.
974    """
975    return self.saver_def
976
977  def to_proto(self, export_scope=None):
978    """Converts this `Saver` to a `SaverDef` protocol buffer.
979
980    Args:
981      export_scope: Optional `string`. Name scope to remove.
982
983    Returns:
984      A `SaverDef` protocol buffer.
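
    For example, the proto can be round-tripped back into a `Saver` (a minimal
    sketch using the defaults):

    ```python
    saver_def = saver.to_proto()
    restored_saver = tf.train.Saver.from_proto(saver_def)
    ```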
985    """
986    if export_scope is None:
987      return self.saver_def
988
989    if not (self.saver_def.filename_tensor_name.startswith(export_scope) and
990            self.saver_def.save_tensor_name.startswith(export_scope) and
991            self.saver_def.restore_op_name.startswith(export_scope)):
992      return None
993
994    saver_def = saver_pb2.SaverDef()
995    saver_def.CopyFrom(self.saver_def)
996    saver_def.filename_tensor_name = ops.strip_name_scope(
997        saver_def.filename_tensor_name, export_scope)
998    saver_def.save_tensor_name = ops.strip_name_scope(
999        saver_def.save_tensor_name, export_scope)
1000    saver_def.restore_op_name = ops.strip_name_scope(
1001        saver_def.restore_op_name, export_scope)
1002    return saver_def
1003
1004  @staticmethod
1005  def from_proto(saver_def, import_scope=None):
1006    """Returns a `Saver` object created from `saver_def`.
1007
1008    Args:
1009      saver_def: a `SaverDef` protocol buffer.
1010      import_scope: Optional `string`. Name scope to use.
1011
1012    Returns:
1013      A `Saver` built from saver_def.
1014    """
1015    return Saver(saver_def=saver_def, name=import_scope)
1016
1017  @property
1018  def last_checkpoints(self):
1019    """List of not-yet-deleted checkpoint filenames.
1020
1021    You can pass any of the returned values to `restore()`.
1022
1023    Returns:
1024      A list of checkpoint filenames, sorted from oldest to newest.
1025    """
1026    return list(self._CheckpointFilename(p) for p in self._last_checkpoints)
1027
1028  def set_last_checkpoints(self, last_checkpoints):
1029    """DEPRECATED: Use set_last_checkpoints_with_time.
1030
1031    Sets the list of old checkpoint filenames.
1032
1033    Args:
1034      last_checkpoints: A list of checkpoint filenames.
1035
1036    Raises:
1037      AssertionError: If last_checkpoints is not a list.
1038    """
1039    assert isinstance(last_checkpoints, list)
1040    # We use a timestamp of +inf so that this checkpoint will never be
    # deleted.  This is both safe and backwards compatible with a previous
1042    # version of the code which used s[1] as the "timestamp".
1043    self._last_checkpoints = [(s, np.inf) for s in last_checkpoints]
1044
1045  def set_last_checkpoints_with_time(self, last_checkpoints_with_time):
1046    """Sets the list of old checkpoint filenames and timestamps.
1047
1048    Args:
1049      last_checkpoints_with_time: A list of tuples of checkpoint filenames and
1050        timestamps.
1051
1052    Raises:
1053      AssertionError: If last_checkpoints_with_time is not a list.
1054    """
1055    assert isinstance(last_checkpoints_with_time, list)
1056    self._last_checkpoints = last_checkpoints_with_time
1057
1058  def recover_last_checkpoints(self, checkpoint_paths):
1059    """Recovers the internal saver state after a crash.
1060
1061    This method is useful for recovering the "self._last_checkpoints" state.
1062
1063    Globs for the checkpoints pointed to by `checkpoint_paths`.  If the files
1064    exist, use their mtime as the checkpoint timestamp.
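
    For example (hypothetical checkpoint paths):

    ```python
    saver.recover_last_checkpoints(['/tmp/model.ckpt-100', '/tmp/model.ckpt-200'])
    ```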
1065
1066    Args:
1067      checkpoint_paths: a list of checkpoint paths.
1068    """
1069    mtimes = checkpoint_management.get_checkpoint_mtimes(checkpoint_paths)
1070    self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes)))
1071
1072  def save(self,
1073           sess,
1074           save_path,
1075           global_step=None,
1076           latest_filename=None,
1077           meta_graph_suffix="meta",
1078           write_meta_graph=True,
1079           write_state=True,
1080           strip_default_attrs=False,
1081           save_debug_info=False):
1082    # pylint: disable=line-too-long
1083    """Saves variables.
1084
1085    This method runs the ops added by the constructor for saving variables.
1086    It requires a session in which the graph was launched.  The variables to
1087    save must also have been initialized.
1088
1089    The method returns the path prefix of the newly created checkpoint files.
1090    This string can be passed directly to a call to `restore()`.
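
    For example (hypothetical session, prefix and step):

    ```python
    save_path = saver.save(sess, '/tmp/train/model.ckpt', global_step=1000)
    # save_path is something like '/tmp/train/model.ckpt-1000' and can later be
    # passed to saver.restore(sess, save_path).
    ```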
1091
1092    Args:
1093      sess: A Session to use to save the variables.
1094      save_path: String.  Prefix of filenames created for the checkpoint.
      global_step: If provided, the global step number is appended to
        `save_path` to create the checkpoint filenames. The optional argument
        can be a `Tensor`, a `Tensor` name or an integer.
      latest_filename: Optional name for the protocol buffer file that will
        contain the list of most recent checkpoints.  That file,
        kept in the same directory as the checkpoint files, is automatically
        managed by the saver to keep track of recent checkpoints.  Defaults to
        'checkpoint'.
1103      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
1104      write_meta_graph: `Boolean` indicating whether or not to write the meta
1105        graph file.
1106      write_state: `Boolean` indicating whether or not to write the
1107        `CheckpointStateProto`.
1108      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
1109        removed from the NodeDefs. For a detailed guide, see
1110        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
        which is in the same directory as save_path and has `_debug` added before
        the file extension. This is only enabled when `write_meta_graph` is
        `True`.
1115
1116    Returns:
1117      A string: path prefix used for the checkpoint files.  If the saver is
1118        sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
1119        is the number of shards created.
1120      If the saver is empty, returns None.
1121
1122    Raises:
1123      TypeError: If `sess` is not a `Session`.
1124      ValueError: If `latest_filename` contains path components, or if it
1125        collides with `save_path`.
1126      RuntimeError: If save and restore ops weren't built.
1127    """
1128    # pylint: enable=line-too-long
1129    if not self._is_built and not context.executing_eagerly():
1130      raise RuntimeError(
1131          "`build()` should be called before save if defer_build==True")
1132    if latest_filename is None:
1133      latest_filename = "checkpoint"
1134    if self._write_version != saver_pb2.SaverDef.V2:
1135      logging.warning("*******************************************************")
1136      logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
1137      logging.warning("Consider switching to the more efficient V2 format:")
1138      logging.warning("   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
1139      logging.warning("now on by default.")
1140      logging.warning("*******************************************************")
1141
1142    if os.path.split(latest_filename)[0]:
1143      raise ValueError("'latest_filename' must not contain path components")
1144
1145    if global_step is not None:
1146      if not isinstance(global_step, compat.integral_types):
1147        global_step = training_util.global_step(sess, global_step)
1148      checkpoint_file = "%s-%d" % (save_path, global_step)
1149      if self._pad_step_number:
1150        # Zero-pads the step numbers, so that they are sorted when listed.
1151        checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
1152    else:
1153      checkpoint_file = save_path
1154      if os.path.basename(
1155          save_path) == latest_filename and not self._sharded:
1156        # Guard against collision between data file and checkpoint state file.
1157        raise ValueError(
1158            "'latest_filename' collides with 'save_path': '%s' and '%s'" %
1159            (latest_filename, save_path))
1160
1161    if (not context.executing_eagerly() and
1162        not isinstance(sess, session.SessionInterface)):
1163      raise TypeError("'sess' must be a Session; %s" % sess)
1164
1165    save_path_parent = os.path.dirname(save_path)
1166    if not self._is_empty:
1167      try:
1168        if context.executing_eagerly():
1169          self._build_eager(
1170              checkpoint_file, build_save=True, build_restore=False)
1171          model_checkpoint_path = self.saver_def.save_tensor_name
1172        else:
1173          model_checkpoint_path = sess.run(
1174              self.saver_def.save_tensor_name,
1175              {self.saver_def.filename_tensor_name: checkpoint_file})
1176
1177        model_checkpoint_path = compat.as_str(model_checkpoint_path)
1178        if write_state:
1179          self._RecordLastCheckpoint(model_checkpoint_path)
1180          checkpoint_management.update_checkpoint_state_internal(
1181              save_dir=save_path_parent,
1182              model_checkpoint_path=model_checkpoint_path,
1183              all_model_checkpoint_paths=self.last_checkpoints,
1184              latest_filename=latest_filename,
1185              save_relative_paths=self._save_relative_paths)
1186          self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
1187      except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
1188        if not gfile.IsDirectory(save_path_parent):
1189          exc = ValueError(
1190              "Parent directory of {} doesn't exist, can't save.".format(
1191                  save_path))
1192        raise exc
1193
1194    if write_meta_graph:
1195      meta_graph_filename = checkpoint_management.meta_graph_filename(
1196          checkpoint_file, meta_graph_suffix=meta_graph_suffix)
1197      if not context.executing_eagerly():
1198        with sess.graph.as_default():
1199          self.export_meta_graph(
1200              meta_graph_filename, strip_default_attrs=strip_default_attrs,
1201              save_debug_info=save_debug_info)
1202
1203    if self._is_empty:
1204      return None
1205    else:
1206      return model_checkpoint_path
1207
1208  def export_meta_graph(self,
1209                        filename=None,
1210                        collection_list=None,
1211                        as_text=False,
1212                        export_scope=None,
1213                        clear_devices=False,
1214                        clear_extraneous_savers=False,
1215                        strip_default_attrs=False,
1216                        save_debug_info=False):
1217    # pylint: disable=line-too-long
1218    """Writes `MetaGraphDef` to save_path/filename.
1219
1220    Args:
1221      filename: Optional meta_graph filename including the path.
1222      collection_list: List of string keys to collect.
1223      as_text: If `True`, writes the meta_graph as an ASCII proto.
1224      export_scope: Optional `string`. Name scope to remove.
1225      clear_devices: Whether or not to clear the device field for an `Operation`
1226        or `Tensor` during export.
      clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that is not associated
        with this Saver.
1230      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
1231        removed from the NodeDefs. For a detailed guide, see
1232        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
        which is in the same directory as filename and has `_debug` added before
        the file extension.
1236
1237    Returns:
1238      A `MetaGraphDef` proto.
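
    A minimal sketch (hypothetical filename) of exporting the graph and
    re-importing it elsewhere:

    ```python
    saver.export_meta_graph('/tmp/my-model.meta')
    # In another program or graph:
    new_saver = tf.train.import_meta_graph('/tmp/my-model.meta')
    ```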
1239    """
1240    # pylint: enable=line-too-long
1241    return export_meta_graph(
1242        filename=filename,
1243        graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
1244        saver_def=self.saver_def,
1245        collection_list=collection_list,
1246        as_text=as_text,
1247        export_scope=export_scope,
1248        clear_devices=clear_devices,
1249        clear_extraneous_savers=clear_extraneous_savers,
1250        strip_default_attrs=strip_default_attrs,
1251        save_debug_info=save_debug_info)
1252
1253  def restore(self, sess, save_path):
1254    """Restores previously saved variables.
1255
1256    This method runs the ops added by the constructor for restoring variables.
1257    It requires a session in which the graph was launched.  The variables to
1258    restore do not have to have been initialized, as restoring is itself a way
1259    to initialize variables.
1260
1261    The `save_path` argument is typically a value previously returned from a
1262    `save()` call, or a call to `latest_checkpoint()`.
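
    For example (hypothetical checkpoint directory):

    ```python
    ckpt_path = tf.train.latest_checkpoint('/tmp/train')
    if ckpt_path is not None:
      saver.restore(sess, ckpt_path)
    ```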
1263
1264    Args:
1265      sess: A `Session` to use to restore the parameters. None in eager mode.
1266      save_path: Path where parameters were previously saved.
1267
1268    Raises:
1269      ValueError: If save_path is None or not a valid checkpoint.
1270    """
1271    if self._is_empty:
1272      return
1273    if save_path is None:
1274      raise ValueError("Can't load save_path when it is None.")
1275
1276    if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)):
1277      raise ValueError("The passed save_path is not a valid checkpoint: "
1278                       + compat.as_text(save_path))
1279
1280    logging.info("Restoring parameters from %s", compat.as_text(save_path))
1281    try:
1282      if context.executing_eagerly():
1283        self._build_eager(save_path, build_save=False, build_restore=True)
1284      else:
1285        sess.run(self.saver_def.restore_op_name,
1286                 {self.saver_def.filename_tensor_name: save_path})
1287    except errors.NotFoundError as err:
1288      # There are three common conditions that might cause this error:
1289      # 0. The file is missing. We ignore this case, as it is checked above.
1290      # 1. This is an object-based checkpoint trying name-based loading.
1291      # 2. The graph has been altered and a variable or other name is missing.
1292
1293      # 1. The checkpoint would not be loaded successfully as is. Try to parse
1294      # it as an object-based checkpoint.
1295      try:
1296        names_to_keys = object_graph_key_mapping(save_path)
1297      except errors.NotFoundError:
1298        # 2. This is not an object-based checkpoint, which likely means there
1299        # is a graph mismatch. Re-raise the original error with
1300        # a helpful message (b/110263146)
1301        raise _wrap_restore_error_with_msg(
1302            err, "a Variable name or other graph key that is missing")
1303
1304      # This is an object-based checkpoint. We'll print a warning and then do
1305      # the restore.
1306      logging.warning(
1307          "Restoring an object-based checkpoint using a name-based saver. This "
1308          "may be somewhat fragile, and will re-build the Saver. Instead, "
1309          "consider loading object-based checkpoints using "
1310          "tf.train.Checkpoint().")
1311      self._object_restore_saver = saver_from_object_based_checkpoint(
1312          checkpoint_path=save_path,
1313          var_list=self._var_list,
1314          builder=self._builder,
1315          names_to_keys=names_to_keys,
1316          cached_saver=self._object_restore_saver)
1317      self._object_restore_saver.restore(sess=sess, save_path=save_path)
1318    except errors.InvalidArgumentError as err:
1319      # There is a mismatch between the graph and the checkpoint being loaded.
1320      # We add a more reasonable error message here to help users (b/110263146)
1321      raise _wrap_restore_error_with_msg(
1322          err, "a mismatch between the current graph and the graph")
1323
1324  @staticmethod
1325  def _add_collection_def(meta_graph_def, key, export_scope=None):
1326    """Adds a collection to MetaGraphDef protocol buffer.
1327
1328    Args:
1329      meta_graph_def: MetaGraphDef protocol buffer.
1330      key: One of the GraphKeys or user-defined string.
1331      export_scope: Optional `string`. Name scope to remove.
1332    """
1333    meta_graph.add_collection_def(meta_graph_def, key,
1334                                  export_scope=export_scope)
1335
1336
1337@tf_export(v1=["train.import_meta_graph"])
1338def import_meta_graph(meta_graph_or_file, clear_devices=False,
1339                      import_scope=None, **kwargs):
1340  """Recreates a Graph saved in a `MetaGraphDef` proto.
1341
1342  This function takes a `MetaGraphDef` protocol buffer as input. If
1343  the argument is a file containing a `MetaGraphDef` protocol buffer,
1344  it constructs a protocol buffer from the file content. The function
1345  then adds all the nodes from the `graph_def` field to the
1346  current graph, recreates all the collections, and returns a saver
1347  constructed from the `saver_def` field.
1348
1349  In combination with `export_meta_graph()`, this function can be used to:
1350
1351  * Serialize a graph along with other Python objects such as `QueueRunner`,
1352    `Variable` into a `MetaGraphDef`.
1353
1354  * Restart training from a saved graph and checkpoints.
1355
1356  * Run inference from a saved graph and checkpoints.
1357
1358  ```Python
1359  ...
1360  # Create a saver.
1361  saver = tf.train.Saver(...variables...)
1362  # Remember the training_op we want to run by adding it to a collection.
1363  tf.add_to_collection('train_op', train_op)
1364  sess = tf.Session()
1365  for step in range(1000000):
1366      sess.run(train_op)
1367      if step % 1000 == 0:
1368          # Saves checkpoint, which by default also exports a meta_graph
1369          # named 'my-model-global_step.meta'.
1370          saver.save(sess, 'my-model', global_step=step)
1371  ```
1372
1373  Later we can continue training from this saved `meta_graph` without building
1374  the model from scratch.
1375
1376  ```Python
1377  with tf.Session() as sess:
1378    new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
1379    new_saver.restore(sess, 'my-save-dir/my-model-10000')
1380    # tf.get_collection() returns a list. In this example we only want the
1381    # first one.
1382    train_op = tf.get_collection('train_op')[0]
1383    for step in range(1000000):
1384      sess.run(train_op)
1385  ```
1386
1387  NOTE: Restarting training from saved `meta_graph` only works if the
1388  device assignments have not changed.
1389
1390  Example 2:
1391  Variables, placeholders, and independent operations can also be stored, as
1392  shown in the following example.
1393
1394  ```Python
1395  # Saving contents and operations.
1396  v1 = tf.placeholder(tf.float32, name="v1")
1397  v2 = tf.placeholder(tf.float32, name="v2")
1398  v3 = tf.multiply(v1, v2)
1399  vx = tf.Variable(10.0, name="vx")
1400  v4 = tf.add(v3, vx, name="v4")
1401  saver = tf.train.Saver([vx])
1402  sess = tf.Session()
1403  sess.run(tf.global_variables_initializer())
1404  sess.run(vx.assign(tf.add(vx, vx)))
1405  result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
1406  print(result)
1407  saver.save(sess, "./model_ex1")
1408  ```
1409
1410  Later this model can be restored and its contents loaded.
1411
1412  ```Python
1413  # Restoring variables and running operations.
1414  saver = tf.train.import_meta_graph("./model_ex1.meta")
1415  sess = tf.Session()
1416  saver.restore(sess, "./model_ex1")
1417  result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
1418  print(result)
1419  ```
1420
1421  Args:
1422    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
1423      the path) containing a `MetaGraphDef`.
1424    clear_devices: Whether or not to clear the device field for an `Operation`
1425      or `Tensor` during import.
1426    import_scope: Optional `string`. Name scope to add. Only used when
1427      initializing from protocol buffer.
1428    **kwargs: Optional keyword arguments.
1429
1430  Returns:
1431    A saver constructed from `saver_def` in `MetaGraphDef` or None.
1432
1433    A None value is returned if no variables exist in the `MetaGraphDef`
1434    (i.e., there are no variables to restore).
1435
1436  Raises:
1437    RuntimeError: If called with eager execution enabled.
1438
1439  @compatibility(eager)
1440  Exporting/importing meta graphs is not supported. No graph exists when eager
1441  execution is enabled.
1442  @end_compatibility
1443  """  # pylint: disable=g-doc-exception
1444  return _import_meta_graph_with_return_elements(
1445      meta_graph_or_file, clear_devices, import_scope, **kwargs)[0]
1446
1447
1448def _import_meta_graph_with_return_elements(
1449    meta_graph_or_file, clear_devices=False, import_scope=None,
1450    return_elements=None, **kwargs):
1451  """Import MetaGraph, and return both a saver and returned elements."""
1452  if context.executing_eagerly():
1453    raise RuntimeError("Exporting/importing meta graphs is not supported when "
1454                       "eager execution is enabled. No graph exists when eager "
1455                       "execution is enabled.")
1456  if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
1457    meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
1458  else:
1459    meta_graph_def = meta_graph_or_file
1460
1461  imported_vars, imported_return_elements = (
1462      meta_graph.import_scoped_meta_graph_with_return_elements(
1463          meta_graph_def,
1464          clear_devices=clear_devices,
1465          import_scope=import_scope,
1466          return_elements=return_elements,
1467          **kwargs))
1468
1469  saver = _create_saver_from_imported_meta_graph(
1470      meta_graph_def, import_scope, imported_vars)
1471  return saver, imported_return_elements
1472
1473
1474def _create_saver_from_imported_meta_graph(
1475    meta_graph_def, import_scope, imported_vars):
1476  """Return a saver for restoring variable values to an imported MetaGraph."""
1477  if meta_graph_def.HasField("saver_def"):
1478    # Infer the scope that is prepended by `import_scoped_meta_graph`.
1479    scope = import_scope
1480    var_names = list(imported_vars.keys())
1481    if var_names:
1482      sample_key = var_names[0]
1483      sample_var = imported_vars[sample_key]
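      # For example, a variable named "new_scope/foo/w:0" imported under the
      # key "foo/w:0" yields the prepended scope "new_scope/".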
1484      scope = sample_var.name[:-len(sample_key)]
1485
1486    return Saver(saver_def=meta_graph_def.saver_def, name=scope)
1487  else:
1488    if variables._all_saveable_objects(scope=import_scope):  # pylint: disable=protected-access
1489      # Return the default saver instance for all graph variables.
1490      return Saver()
1491    else:
1492      # If no graph variables exist, then a Saver cannot be constructed.
1493      logging.info("Saver not created because there are no variables in the"
1494                   " graph to restore")
1495      return None
1496
1497
1498@tf_export(v1=["train.export_meta_graph"])
1499def export_meta_graph(filename=None,
1500                      meta_info_def=None,
1501                      graph_def=None,
1502                      saver_def=None,
1503                      collection_list=None,
1504                      as_text=False,
1505                      graph=None,
1506                      export_scope=None,
1507                      clear_devices=False,
1508                      clear_extraneous_savers=False,
1509                      strip_default_attrs=False,
1510                      save_debug_info=False,
1511                      **kwargs):
1512  # pylint: disable=line-too-long
1513  """Returns `MetaGraphDef` proto. Optionally writes it to filename.
1514
1515  This function exports the graph, saver, and collection objects into a
1516  `MetaGraphDef` protocol buffer so that it can be imported at a later time
1517  or in another location to restart training, run inference, or be used as
1518  a subgraph.
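
  For example, the default graph can be exported as follows (an illustrative
  sketch; the filename is a placeholder):

  ```Python
  # Builds a small graph in the default graph.
  v = tf.Variable(0, name="counter")
  inc = tf.assign_add(v, 1)
  # Exports the default graph to a `MetaGraphDef` proto and writes it to disk.
  meta_graph_def = tf.train.export_meta_graph(filename="/tmp/my-model.meta")
  ```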
1519
1520  Args:
1521    filename: Optional filename including the path for writing the
1522      generated `MetaGraphDef` protocol buffer.
1523    meta_info_def: `MetaInfoDef` protocol buffer.
1524    graph_def: `GraphDef` protocol buffer.
1525    saver_def: `SaverDef` protocol buffer.
1526    collection_list: List of string keys to collect.
1527    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
1528    graph: The `Graph` to export. If `None`, use the default graph.
1529    export_scope: Optional `string`. Name scope under which to extract
1530      the subgraph. The scope name will be stripped from the node definitions
1531      for easy import later into new name scopes. If `None`, the whole graph
1532      is exported. graph_def and export_scope cannot both be specified.
1533    clear_devices: Whether or not to clear the device field for an `Operation`
1534      or `Tensor` during export.
1535    clear_extraneous_savers: Remove any Saver-related information from the
1536      graph (both Save/Restore ops and SaverDefs) that is not associated
1537      with the provided SaverDef.
1538    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
1539      removed from the NodeDefs. For a detailed guide, see
1540      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
1541    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
1542      which is in the same directory as the filename and has `_debug` added
1543      before the file extension.
1544    **kwargs: Optional keyword arguments.
1545
1546  Returns:
1547    A `MetaGraphDef` proto.
1548
1549  Raises:
1550    ValueError: When the `GraphDef` is larger than 2GB.
1551    RuntimeError: If called with eager execution enabled.
1552
1553  @compatibility(eager)
1554  Exporting/importing meta graphs is not supported unless both `graph_def` and
1555  `graph` are provided. No graph exists when eager execution is enabled.
1556  @end_compatibility
1557  """
1558  # pylint: enable=line-too-long
1559  if context.executing_eagerly() and not (graph_def is not None and
1560                                          graph is not None):
1561    raise RuntimeError("Exporting/importing meta graphs is not supported when "
1562                       "eager execution is enabled. No graph exists when eager "
1563                       "execution is enabled.")
1564  meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
1565      filename=filename,
1566      meta_info_def=meta_info_def,
1567      graph_def=graph_def,
1568      saver_def=saver_def,
1569      collection_list=collection_list,
1570      as_text=as_text,
1571      graph=graph,
1572      export_scope=export_scope,
1573      clear_devices=clear_devices,
1574      clear_extraneous_savers=clear_extraneous_savers,
1575      strip_default_attrs=strip_default_attrs,
1576      save_debug_info=save_debug_info,
1577      **kwargs)
1578  return meta_graph_def
1579
1580
1581def _wrap_restore_error_with_msg(err, extra_verbiage):
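  """Returns a new error of the same type as `err` with a clearer message."""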
1582  err_msg = ("Restoring from checkpoint failed. This is most likely "
1583             "due to {} from the checkpoint. Please ensure that you "
1584             "have not altered the graph expected based on the checkpoint. "
1585             "Original error:\n\n{}").format(extra_verbiage, err.message)
1586  return err.__class__(err.node_def, err.op, err_msg)
1587
1588
1589ops.register_proto_function(
1590    ops.GraphKeys.SAVERS,
1591    proto_type=saver_pb2.SaverDef,
1592    to_proto=Saver.to_proto,
1593    from_proto=Saver.from_proto)
1594
1595
1596def object_graph_key_mapping(checkpoint_path):
1597  """Return name to key mappings from the checkpoint.
1598
1599  Args:
1600    checkpoint_path: string, path to object-based checkpoint
1601
1602  Returns:
1603    Dictionary mapping tensor names to checkpoint keys.
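
  For example, a checkpoint holding a single dense layer might produce a
  mapping like the following (an illustrative sketch; actual names depend on
  the saved object graph):

  ```Python
  {"dense/kernel": "dense/kernel/.ATTRIBUTES/VARIABLE_VALUE"}
  ```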
1604  """
1605  reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
1606  object_graph_string = reader.get_tensor(
1607      trackable.OBJECT_GRAPH_PROTO_KEY)
1608  object_graph_proto = (
1609      trackable_object_graph_pb2.TrackableObjectGraph())
1610  object_graph_proto.ParseFromString(object_graph_string)
1611  names_to_keys = {}
1612  for node in object_graph_proto.nodes:
1613    for attribute in node.attributes:
1614      names_to_keys[attribute.full_name] = attribute.checkpoint_key
1615  return names_to_keys
1616
1617
1618def saver_from_object_based_checkpoint(
1619    checkpoint_path, var_list=None, builder=None, names_to_keys=None,
1620    cached_saver=None):
1621  """Return a `Saver` which reads from an object-based checkpoint.
1622
1623  This function validates that all variables in `var_list` are remapped in
1624  the object-based checkpoint (or `names_to_keys` dict if provided). A
1625  saver will be created with the list of remapped variables.
1626
1627  The `cached_saver` argument allows the user to pass in a previously created
1628  saver, so that multiple `saver.restore()` calls don't pollute the graph
1629  during graph building. This assumes that the keys are consistent: the
1630    1) `checkpoint_path` checkpoint, and
1631    2) checkpoint used to create the `cached_saver`
1632  are the same type of object-based checkpoint. If this argument is set, this
1633  function will simply validate that all variables have been remapped by the
1634  checkpoint at `checkpoint_path`.
1635
1636  Note that in general, `tf.train.Checkpoint` should be used to restore/save an
1637  object-based checkpoint.
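
  For example, a name-based `Saver` can be built from an object-based
  checkpoint roughly as follows (an illustrative sketch; the checkpoint path
  is a placeholder):

  ```Python
  # `ckpt_path` points at a checkpoint written by `tf.train.Checkpoint`.
  ckpt_path = "/tmp/object_based/ckpt-1"
  saver = saver_from_object_based_checkpoint(ckpt_path)
  with tf.Session() as sess:
    saver.restore(sess, ckpt_path)
  ```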
1638
1639  Args:
1640    checkpoint_path: string, path to object-based checkpoint
1641    var_list: list of `Variables` that appear in the checkpoint. If `None`,
1642      `var_list` will be set to all saveable objects.
1643    builder: a `BaseSaverBuilder` instance. If `None`, a new `BulkSaverBuilder`
1644      will be created.
1645    names_to_keys: dict mapping string tensor names to checkpoint keys. If
1646      `None`, this dict will be generated from the checkpoint file.
1647    cached_saver: Cached `Saver` object with remapped variables.
1648
1649  Returns:
1650    `Saver` with remapped variables for reading from an object-based checkpoint.
1651
1652  Raises:
1653    ValueError: If the checkpoint provided is not an object-based checkpoint.
1654    NotFoundError: If one of the variables in `var_list` cannot be found in the
1655      checkpoint. This could mean the checkpoint or `names_to_keys` mapping is
1656      missing the variable.
1657  """
1658  if names_to_keys is None:
1659    try:
1660      names_to_keys = object_graph_key_mapping(checkpoint_path)
1661    except errors.NotFoundError:
1662      raise ValueError("Checkpoint in %s not an object-based checkpoint."
1663                       % checkpoint_path)
1664  if var_list is None:
1665    var_list = variables._all_saveable_objects()  # pylint: disable=protected-access
1666  if builder is None:
1667    builder = BulkSaverBuilder()
1668
1669  saveables = saveable_object_util.validate_and_slice_inputs(var_list)
1670  current_names = set()
1671  for saveable in saveables:
1672    for spec in saveable.specs:
1673      current_names.add(spec.name)
1674  previous_names = set(names_to_keys.keys())
1675  missing_names = current_names - previous_names
1676  if missing_names:
1677    extra_names = previous_names - current_names
1678    intersecting_names = previous_names.intersection(current_names)
1679    raise errors.NotFoundError(
1680        None, None,
1681        message=(
1682            "\n\nExisting variables not in the checkpoint: %s\n\n"
1683            "Variables names when this checkpoint was written which don't "
1684            "exist now: %s\n\n"
1685            "(%d variable name(s) did match)\n\n"
1686            "Could not find some variables in the checkpoint (see names "
1687            "above). Saver was attempting to load an object-based checkpoint "
1688            "(saved using tf.train.Checkpoint or tf.keras.Model.save_weights) "
1689            "using variable names. If the checkpoint was written with eager "
1690            "execution enabled, it's possible that variable names have "
1691            "changed (for example missing a '_1' suffix). It's also "
1692            "possible that there are new variables which did not exist "
1693            "when the checkpoint was written. You can construct a "
1694            "Saver(var_list=...) with only the variables which previously "
1695            "existed, and if variable names have changed you may need to "
1696            "make this a dictionary with the old names as keys. If you're "
1697            "using an Estimator, you'll need to return a tf.train.Saver "
1698            "inside a tf.train.Scaffold from your model_fn.")
1699        % (", ".join(sorted(missing_names)), ", ".join(sorted(extra_names)),
1700           len(intersecting_names)))
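  # Rewrite each spec so that it reads from the object-based checkpoint key
  # instead of the variable's own name.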
1701  for saveable in saveables:
1702    for spec in saveable.specs:
1703      spec.name = names_to_keys[spec.name]
1704  if cached_saver is None:
1705    return Saver(saveables)
1706  return cached_saver
1707