1"""Utilities for saving/loading Trackable objects."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import collections
22import os
23import weakref
24
25import six
26
27from tensorflow.core.protobuf import trackable_object_graph_pb2
28from tensorflow.python.client import session as session_lib
29from tensorflow.python.eager import context
30from tensorflow.python.eager import def_function
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import errors_impl
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import tensor_shape
36from tensorflow.python.framework import tensor_util
37from tensorflow.python.lib.io import file_io
38from tensorflow.python.ops import array_ops
39from tensorflow.python.ops import gen_io_ops as io_ops
40from tensorflow.python.ops import init_ops
41from tensorflow.python.ops import variable_scope
42from tensorflow.python.ops import variables
43from tensorflow.python.platform import tf_logging as logging
44from tensorflow.python.training import checkpoint_management
45from tensorflow.python.training import py_checkpoint_reader
46from tensorflow.python.training import saver as v1_saver_lib
47from tensorflow.python.training.saving import functional_saver
48from tensorflow.python.training.saving import saveable_object_util
49from tensorflow.python.training.tracking import base
50from tensorflow.python.training.tracking import data_structures
51from tensorflow.python.training.tracking import graph_view as graph_view_lib
52from tensorflow.python.training.tracking import tracking
53from tensorflow.python.util import compat
54from tensorflow.python.util import deprecation
55from tensorflow.python.util import lazy_loader
56from tensorflow.python.util import object_identity
57from tensorflow.python.util import tf_contextlib
58from tensorflow.python.util.tf_export import tf_export
59
60
61# Loaded lazily due to a circular dependency.
62keras_backend = lazy_loader.LazyLoader(
63    "keras_backend", globals(),
64    "tensorflow.python.keras.backend")
65
66
def get_session():
  """Returns the default session, falling back to the Keras session."""
  # Prefer TF's default session since get_session from Keras has side-effects.
  session = ops.get_default_session()
  if session is None:
    session = keras_backend.get_session()
  return session


class _ObjectGraphProtoPrettyPrinter(object):
  """Lazily traverses an object graph proto to pretty print names.

  If no calls to `node_names` are made, this object has no performance
  overhead. On the other hand, it will only traverse the object graph once, so
  repeated naming is cheap after the first.
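
  Example (a hypothetical sketch; `proto` stands for a parsed
  `TrackableObjectGraph` protocol buffer):

  ```python
  printer = _ObjectGraphProtoPrettyPrinter(proto)
  printer.node_names[0]  # "(root)"; other ids map to dotted attribute paths.
  ```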
  """

  def __init__(self, object_graph_proto):
    self._object_graph_proto = object_graph_proto
    self._node_name_cache = None

  @property
  def node_names(self):
    """Lazily creates a mapping from node id to ("path", "to", "root")."""
    if self._node_name_cache is not None:
      return self._node_name_cache
    path_to_root = {}
    path_to_root[0] = ("(root)",)
    to_visit = collections.deque([0])
    while to_visit:
      node_id = to_visit.popleft()
      obj = self._object_graph_proto.nodes[node_id]
      for child in obj.children:
        if child.node_id not in path_to_root:
          path_to_root[child.node_id] = (
              path_to_root[node_id] + (child.local_name,))
          to_visit.append(child.node_id)

    node_names = {}
    for node_id, path in path_to_root.items():
      node_names[node_id] = ".".join(path)

    for node_id, node in enumerate(self._object_graph_proto.nodes):
      for slot_reference in node.slot_variables:
        node_names[slot_reference.slot_variable_node_id] = (
            "{}'s state '{}' for {}".format(
                node_names[node_id], slot_reference.slot_name,
                node_names[slot_reference.original_variable_node_id]))
    self._node_name_cache = node_names
    return node_names


class _CheckpointRestoreCoordinatorDeleter(object):
  """Deleter to avoid overriding _CheckpointRestoreCoordinator.__del__()."""

  def __init__(self, expect_partial, object_graph_proto, matched_proto_ids,
               unused_attributes):
    self.expect_partial = expect_partial
    self.object_graph_proto = object_graph_proto
    self.matched_proto_ids = matched_proto_ids
    self.unused_attributes = unused_attributes

  def set_expect_partial(self, expect_partial):
    self.expect_partial = expect_partial

  def __del__(self):
    if self.expect_partial:
      return
    if logging is None:
      # The logging module may have been unloaded when __del__ is called.
      log_fn = print
    else:
      log_fn = logging.warning
    printed_warning = False
    pretty_printer = _ObjectGraphProtoPrettyPrinter(self.object_graph_proto)
    for node_id in range(len(self.object_graph_proto.nodes)):
      if node_id not in self.matched_proto_ids:
        log_fn("Unresolved object in checkpoint: {}"
               .format(pretty_printer.node_names[node_id]))
        printed_warning = True
    for node_id, attribute_name in self.unused_attributes.items():
      log_fn("Unused attribute in object {}: {}"
             .format(pretty_printer.node_names[node_id], attribute_name))
      printed_warning = True
    if printed_warning:
      log_fn(
          "A checkpoint was restored (e.g. tf.train.Checkpoint.restore or "
          "tf.keras.Model.load_weights) but not all checkpointed values were "
          "used. See above for specific issues. Use expect_partial() on the "
          "load status object, e.g. "
          "tf.train.Checkpoint.restore(...).expect_partial(), to silence these "
          "warnings, or use assert_consumed() to make the check explicit. See "
          "https://www.tensorflow.org/guide/checkpoint#loading_mechanics"
          " for details.")


class _CheckpointRestoreCoordinator(object):
  """Holds the status of an object-based checkpoint load."""

  def __init__(self, object_graph_proto, save_path, save_path_tensor,
               restore_op_cache, graph_view):
    """Specify the checkpoint being loaded.

    Args:
      object_graph_proto: The TrackableObjectGraph protocol buffer associated
        with this checkpoint.
      save_path: A string, the path to the checkpoint, as returned by
        `tf.train.latest_checkpoint`.
      save_path_tensor: A string `Tensor` which contains or will be fed the save
        path.
      restore_op_cache: A dictionary shared between
        `_CheckpointRestoreCoordinator`s for the same Python objects, used to
        look up restore ops by name to avoid re-creating them across multiple
        `restore()` calls.
      graph_view: A graph_view_lib.ObjectGraphView object for the restored
        objects.
    """
    self.object_graph_proto = object_graph_proto
    self.restore_uid = ops.uid()
    # Maps from proto ids to lists of attributes which were in the checkpoint
    # but not loaded into any object, for error checking.
    self.unused_attributes = {}
    # Dictionary mapping from an id in the protocol buffer flat array to
    # Trackable Python objects. This mapping may be deferred if a
    # checkpoint is restored before all dependencies have been tracked. Uses
    # weak references so that partial restorations don't create reference cycles
    # (as objects with deferred dependencies will generally have references to
    # this object).
    self.object_by_proto_id = weakref.WeakValueDictionary()
    self.matched_proto_ids = set()
    # A set of all Python objects we've seen as dependencies, even if we didn't
    # use them (for example because of inconsistent references when
    # loading). Used to make status assertions fail when loading checkpoints
    # that don't quite match.
    self.all_python_objects = object_identity.ObjectIdentityWeakSet()
    self.save_path_tensor = save_path_tensor
    self.save_path_string = save_path
    # Maps variable names to dtypes, read eagerly from the checkpoint. A
    # NewCheckpointReader is created lazily again in restore_saveables() for
    # streaming Python state restoration.
    self.dtype_map = py_checkpoint_reader.NewCheckpointReader(
        save_path).get_variable_to_dtype_map()
    # When graph building, contains a list of ops to run to restore objects from
    # this checkpoint.
    self.restore_ops = []
    self.restore_ops_by_name = restore_op_cache
    self.graph_view = graph_view
    self.new_restore_ops_callback = None
    # A mapping from optimizer proto ids to lists of slot variables to be
    # restored when the optimizer is tracked. Only includes slot variables whose
    # regular variables have already been created, and only for optimizer
    # objects which have not yet been created/tracked.
    self.deferred_slot_restorations = {}
    # A mapping from variable proto ids to lists of slot variables to be
    # restored when the variable is created/tracked. These get shifted over to
    # deferred_slot_restorations if the optimizer hasn't been created when that
    # happens.
    self.slot_restorations = {}
    # Controls whether errors are printed in __del__ if some objects did not
    # match.
    self.expect_partial_attr = False
    for node_index, node in enumerate(self.object_graph_proto.nodes):
      for slot_reference in node.slot_variables:
        # `node` refers to an `Optimizer`, since only these have slot variables.
        self.slot_restorations.setdefault(
            slot_reference.original_variable_node_id, []).append(
                base._SlotVariableRestoration(  # pylint: disable=protected-access
                    optimizer_id=node_index,
                    slot_variable_id=slot_reference.slot_variable_node_id,
                    slot_name=slot_reference.slot_name))

    self._deleter = _CheckpointRestoreCoordinatorDeleter(
        self.expect_partial_attr,
        self.object_graph_proto,
        self.matched_proto_ids,
        self.unused_attributes)

  @property
  def expect_partial(self):
    return self.expect_partial_attr

  @expect_partial.setter
  def expect_partial(self, expect_partial):
    self.expect_partial_attr = expect_partial
    self._deleter.set_expect_partial(expect_partial)

  def new_restore_ops(self, new_ops):
    self.restore_ops.extend(new_ops)
    if self.new_restore_ops_callback:
      self.new_restore_ops_callback(new_ops)  # pylint: disable=not-callable

  def restore_saveables(self, tensor_saveables, python_saveables):
    """Run or build restore operations for SaveableObjects.

    Args:
      tensor_saveables: `SaveableObject`s which correspond to Tensors.
      python_saveables: `PythonStateSaveable`s which correspond to Python
        values.

    Returns:
      When graph building, a list of restore operations, either cached or newly
      created, to restore `tensor_saveables`.
    """
    restore_ops = []
    # Eagerly run restorations for Python state.
    reader = None
    for saveable in python_saveables:
      if reader is None:
        # Lazily create the NewCheckpointReader, since this requires file access
        # and we may not have any Python saveables.
        reader = py_checkpoint_reader.NewCheckpointReader(self.save_path_string)
      spec_names = [spec.name for spec in saveable.specs]
      saveable.python_restore([reader.get_tensor(name) for name in spec_names])

    # If we have new SaveableObjects, extract and cache restore ops.
    if tensor_saveables:
      validated_saveables = saveable_object_util.validate_and_slice_inputs(
          tensor_saveables)
      validated_names = set(saveable.name for saveable in validated_saveables)
      if set(tensor_saveables.keys()) != validated_names:
        raise AssertionError(
            ("Saveable keys changed when validating. Got back %s, was "
             "expecting %s") % (tensor_saveables.keys(), validated_names))
      new_restore_ops = functional_saver.MultiDeviceSaver(
          validated_saveables).restore(self.save_path_tensor)
      if not context.executing_eagerly():
        for name, restore_op in sorted(new_restore_ops.items()):
          restore_ops.append(restore_op)
          assert name not in self.restore_ops_by_name
          self.restore_ops_by_name[name] = restore_op
    return restore_ops


class _NameBasedRestoreCoordinator(object):
  """Keeps the status of a name-based checkpoint restore."""

  def __init__(self, save_path, dtype_map=None):
    self.save_path = save_path
    self.dtype_map = dtype_map
    # A map from trackable objects to unused attribute names. We don't have
    # proto IDs when doing a name-based restore, so the map keys differ from
    # those in _CheckpointRestoreCoordinator.
    self.unused_attributes = object_identity.ObjectIdentityWeakKeyDictionary()
    self.restore_uid = ops.uid()

  def globally_named_object_attributes(self, trackable):
    """Create globally named SaveableObjects from attributes.

    If an object's attribute has no global name specified (default construction
    for the SaveableObject factory), records the failure in
    `self.unused_attributes` (which can then be used to make status assertions
    fail; see `NameBasedSaverStatus`).

    Args:
      trackable: An object to save.

    Yields:
      SaveableObjects for `trackable`'s attributes.
    """
    for attribute_name, saveable_factory in (
        trackable._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
      if callable(saveable_factory):
        try:
          saveable = saveable_factory()
        except TypeError:
          # A TypeError means the factory has no default for its name= argument,
          # so there is no global name and the object can't be saved/restored
          # using a name-based checkpoint. Still construct the saveable (with an
          # empty name) to check whether restoring it is optional; record the
          # failure, which makes assert_consumed() fail, only when it is not.
          if not saveable_factory("").optional_restore:
            self.unused_attributes.setdefault(trackable,
                                              []).append(attribute_name)
          continue
      else:
        saveable = saveable_factory
      names_to_saveables = saveable_object_util.op_list_to_dict(
          [saveable], convert_variable_to_tensor=False)
      for name, op in names_to_saveables.items():
        for saveable_object in saveable_object_util.saveable_objects_for_op(
            op=op, name=name):
          yield saveable_object

  def eager_restore(self, trackable):
    """Runs restore ops for `trackable`'s attributes."""
    # When graph building, we don't add any restore ops to the graph until
    # run_restore_ops/initialize_or_restore on the status object for name-based
    # checkpoints.
    assert context.executing_eagerly()
    for saveable in self.globally_named_object_attributes(trackable):
      restored_tensors = []
      tensor_missing = False
      for spec in saveable.specs:
        if spec.name in self.dtype_map:
          with ops.device("cpu:0"):
            restored, = io_ops.restore_v2(
                prefix=self.save_path,
                tensor_names=[spec.name],
                shape_and_slices=[""],
                dtypes=[self.dtype_map[spec.name]],
                name="%s_checkpoint_read" % (spec.name,))
          restored_tensors.append(array_ops.identity(restored))
        else:
          tensor_missing = True

      if tensor_missing:
        # Record that this variable didn't match so assertions will fail.
        self.unused_attributes.setdefault(trackable, []).append(saveable.name)
      else:
        # Ignores values missing from the checkpoint, as with object-based
        # restore. Status assertions can be used to check exact matches,
        # although it's unlikely to ever happen for name-based checkpoints.
        saveable.restore(
            restored_tensors=restored_tensors, restored_shapes=None)


# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange
# or consolidating the implementation with get_variable.
def _default_getter(name,
                    shape,
                    dtype,
                    initializer=None,
                    partition_info=None,
                    **kwargs):
  """A pared-down version of get_variable which does not reuse variables."""
  dtype = dtypes.as_dtype(dtype)
  shape_object = tensor_shape.as_shape(shape)
  with ops.init_scope():
    if initializer is None:
      initializer, initializing_from_value = (
          variable_scope._get_default_variable_store()._get_default_initializer(  # pylint: disable=protected-access
              name=name,
              shape=shape_object,
              dtype=dtype))
    else:
      initializing_from_value = not callable(initializer)
    # Same logic as get_variable
    variable_dtype = dtype.base_dtype
    if initializing_from_value:
      if shape is not None:
        raise ValueError("If initializer is a constant, do not specify shape.")
      initial_value = initializer
    else:
      # Instantiate initializer if provided initializer is a type object.
      if isinstance(initializer, type(init_ops.Initializer)):
        initializer = initializer(dtype=dtype)

      def initial_value():
        return initializer(
            shape_object.as_list(), dtype=dtype, partition_info=partition_info)

    return variables.VariableV1(
        initial_value=initial_value,
        name=name,
        dtype=variable_dtype,
        use_resource=True,
        **kwargs)


def add_variable(trackable,
                 name,
                 shape=None,
                 dtype=dtypes.float32,
                 initializer=None,
                 trainable=True):
  """Add a variable to a Trackable with no scope influence.
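
  A minimal sketch (`obj` stands for any `Trackable` instance; the name and
  value are illustrative):

  ```python
  v = add_variable(obj, name="v", initializer=1.)  # A scalar, no scope prefix.
  ```
  """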
  return trackable._add_variable_with_custom_getter(  # pylint: disable=protected-access
      name=name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      getter=_default_getter,
      trainable=trainable)


def object_metadata(save_path):
  """Retrieves information about the objects in a checkpoint.

  Example usage:

  ```python
  object_graph = tf.contrib.checkpoint.object_metadata(
      tf.train.latest_checkpoint(checkpoint_directory))
  ckpt_variable_names = set()
  for node in object_graph.nodes:
    for attribute in node.attributes:
      ckpt_variable_names.add(attribute.full_name)
  ```

  Args:
    save_path: The path to the checkpoint, as returned by `save` or
      `tf.train.latest_checkpoint`.

  Returns:
    A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer.

  Raises:
    ValueError: If an object graph was not found in the checkpoint.
  """
  reader = py_checkpoint_reader.NewCheckpointReader(save_path)
  try:
    object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
  except errors_impl.NotFoundError:
    raise ValueError(
        ('The specified checkpoint "%s" does not appear to be object-based (it '
         'is missing the key "%s"). Likely it was created with a name-based '
         "saver and does not contain an object dependency graph.") %
        (save_path, base.OBJECT_GRAPH_PROTO_KEY))
  object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  object_graph_proto.ParseFromString(object_graph_string)
  return object_graph_proto


def list_objects(root_trackable):
  """Traverse the object graph and list all accessible objects.

  Looks for `Trackable` objects which are dependencies of
  `root_trackable`. Includes slot variables only if the variable they are
  slotting for and the optimizer are dependencies of `root_trackable`
  (i.e. if they would be saved with a checkpoint).
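
  A short sketch of intended use (`tf` here is TensorFlow's public API):

  ```python
  root = tf.train.Checkpoint(v=tf.Variable(1.))
  assert any(obj is root.v for obj in list_objects(root))
  ```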

  Args:
    root_trackable: A `Trackable` object whose dependencies should be flattened.

  Returns:
    A flat list of objects.
  """
  return graph_view_lib.ObjectGraphView(root_trackable).list_objects()


def gather_initializers(root_trackable):
  """Traverse the object graph and find initialization ops.

  Looks for `Trackable` objects which are dependencies of
  `root_trackable` and which have an `initializer` property. Includes
  initializers for slot variables only if the variable they are slotting for and
  the optimizer are dependencies of `root_trackable` (i.e. if they would be
  saved with a checkpoint).
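
  A sketch of intended use when graph building (assumes a `session`; the
  variable name is illustrative):

  ```python
  root = tf.train.Checkpoint(v=tf.compat.v1.get_variable("v", shape=[]))
  session.run(gather_initializers(root))  # Runs the initializer for `v`.
  ```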

  Args:
    root_trackable: A `Trackable` object to gather initializers for.

  Returns:
    A list of initialization ops.
  """
  trackable_objects = list_objects(root_trackable)
  return [
      c.initializer
      for c in trackable_objects
      if hasattr(c, "initializer") and c.initializer is not None
  ]


@tf_contextlib.contextmanager
def capture_dependencies(template):
  """Capture variables created within this scope as `Template` dependencies.

  Requires that `template.variable_scope` is active.

  This scope is intended as a compatibility measure, allowing a trackable
  object to add dependencies on variables created in a block of code which is
  not aware of object-based saving (and instead uses variable names
  heavily). This is how `Template` objects add dependencies on variables and
  sub-`Template`s. Where possible, use `tf.compat.v1.make_template` directly.
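
  Roughly how `Template` arranges this (a sketch; `wrapped_fn` is an
  illustrative name for the function the template wraps):

  ```python
  with tf.compat.v1.variable_scope(template.variable_scope):
    with capture_dependencies(template):
      outputs = wrapped_fn()  # Variables created here become dependencies.
  ```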

  Args:
    template: The `Template` object to register dependencies with.

  Yields:
    None (when used as a context manager).
  """
  name_prefix = template.variable_scope.name

  def _trackable_custom_creator(next_creator,
                                name,
                                initial_value,
                                trackable_parent=None,
                                **kwargs):
    """A variable creation hook which adds Trackable dependencies.

    Set for example during a `Template`'s first wrapped function
    execution. Ensures that (a) `template` depends on any trackable objects
    which create variables inside their own nested `capture_dependencies`
    scopes, and (b) that any variables not in a more deeply nested scope are
    added as dependencies directly.

    The `trackable_parent` argument is passed between custom creators but
    ignored when the variable object itself is created. This argument indicates
    (if not `None`) that a more deeply nested scope has already added the
    variable as a dependency, and that parent scopes should add a dependency on
    that object rather than on the variable directly.

    Args:
      next_creator: See `variable_scope.variable_creator_scope`; the next
        creator in the chain.
      name: The (full, scope-influenced) name of the variable. The `name_prefix`
        itself is stripped for the purposes of object-based dependency tracking,
        but scopes opened within this scope are respected.
      initial_value: See `variable_scope.variable_creator_scope`. Taken
        explicitly so the argument can be re-named and used with
        `Trackable._add_variable_with_custom_getter`.
      trackable_parent: If not None, a more deeply nested trackable object and
        its name prefix which were passed to `capture_dependencies` to add a
        dependency on (rather than depending on the variable directly).
      **kwargs: Passed through to the next creator.

    Returns:
      The output of `next_creator`: the fetched/created variable object.
    """

    def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
      # Drop the scope-stripped name; we don't want to propagate it.
      inner_kwargs.pop("name")
      return next_creator(initial_value=initializer, name=name, **inner_kwargs)

    if name is not None and name.startswith(name_prefix):
      scope_stripped_name = name[len(name_prefix) + 1:]
      if not trackable_parent:
        return template._add_variable_with_custom_getter(  # pylint: disable=protected-access
            initializer=initial_value,
            name=scope_stripped_name,
            getter=_call_next_creator_renaming_initializer,
            # Disable error checking for Trackable. Exceptions are instead
            # raised if necessary when the object-based saver tries to
            # save/restore the object.
            overwrite=True,
            trackable_parent=(template, name_prefix),
            **kwargs)
      else:
        parent_object, parent_name_prefix = trackable_parent
        template._track_trackable(  # pylint: disable=protected-access
            parent_object,
            name=parent_name_prefix[len(name_prefix) + 1:],
            overwrite=True)
    return next_creator(
        name=name,
        initial_value=initial_value,
        trackable_parent=(template, name_prefix),
        **kwargs)

  with variable_scope.variable_creator_scope(_trackable_custom_creator):
    yield


class _LoadStatus(object):
  """Abstract base for load status callbacks."""

  @abc.abstractmethod
  def assert_consumed(self):
    """Raises an exception unless a non-trivial restoration has completed."""
    pass

  @abc.abstractmethod
  def assert_existing_objects_matched(self):
    """Raises an exception unless existing Python objects have been matched."""
    pass

  @abc.abstractmethod
  def assert_nontrivial_match(self):
    """Raises an exception if only the root object matched."""
    pass

  @abc.abstractmethod
  def run_restore_ops(self, session=None):
    """Runs restore ops from the checkpoint. Requires a valid checkpoint."""
    pass

  @abc.abstractmethod
  def initialize_or_restore(self, session=None):
    """Runs restore ops from the checkpoint, or initializes variables."""
    pass

  def expect_partial(self):
    """Silence warnings about incomplete checkpoint restores."""
    return self


def streaming_restore(status, session=None):
  """When graph building, runs restore ops as soon as they come in.
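
  A sketch of intended use when graph building (`saver` here stands for a
  `TrackableSaver`; the names are illustrative):

  ```python
  status = saver.restore(checkpoint_path)
  streaming_restore(status, session)  # Existing and future restore ops run.
  ```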

  Args:
    status: A `_LoadStatus` object from an object-based saver's `restore()`.
      Streaming restore from name-based checkpoints is not currently supported.
    session: A session to run new restore ops in.
  """
  if context.executing_eagerly():
    # Streaming restore is the default/only behavior when executing eagerly.
    return
  if session is None:
    session = get_session()
  if isinstance(status, NameBasedSaverStatus):
    raise NotImplementedError(
        "Streaming restore not supported from name-based checkpoints when "
        "graph building. File a feature request if this limitation bothers "
        "you. As a workaround, consider either using tf.train.Checkpoint to "
        "load name-based checkpoints or enabling eager execution.")
  status.run_restore_ops(session=session)
  # pylint: disable=protected-access
  status._checkpoint.new_restore_ops_callback = (
      lambda ops: session.run(ops, feed_dict=status._feed_dict))
  # pylint: enable=protected-access


def _objects_with_attributes(full_list):
  """Filters out objects with no direct variable dependencies for assertions."""
  return [o for o in full_list if o._gather_saveables_for_checkpoint()]  # pylint: disable=protected-access


class CheckpointLoadStatus(_LoadStatus):
  """Checks the status of checkpoint loading and manages restore ops.

  Returned from `Saver.restore`. Since `restore` may defer the loading of values
  in the checkpoint which don't yet have corresponding Python objects,
  `CheckpointLoadStatus` provides a callback to verify that checkpoint loading
  is complete (`assert_consumed`).

  When graph building, `restore` does not run restore ops itself since their
  creation may be deferred. The `run_restore_ops` method must be called once all
  Python objects with values to restore have been created and added to the
  dependency graph (this does not necessarily have to be the whole checkpoint;
  calling `run_restore_ops` while `assert_consumed` fails is supported and will
  partially restore the checkpoint).

  See `Saver.restore` for usage examples.
  """

  def __init__(self, checkpoint, feed_dict, graph_view):
    self._checkpoint = checkpoint
    self._feed_dict = feed_dict
    self._graph_view = graph_view
    # Keep a reference to the root, since graph_view might only have a weakref.
    self._root = graph_view.root

  def assert_consumed(self):
    """Asserts that all objects in the checkpoint have been created/matched.

    Returns:
      `self` for chaining.

    Raises:
      AssertionError: If there are any Python objects in the dependency graph
        which have not been restored from this checkpoint or a later `restore`,
        or if there are any checkpointed values which have not been matched to
        Python objects.
    """
    pretty_printer = _ObjectGraphProtoPrettyPrinter(
        self._checkpoint.object_graph_proto)
    self.assert_existing_objects_matched()
    for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
      if not node.attributes:
        # Only raise exceptions for the nodes with attributes themselves. Either
        # they're ultimately not important, or they have a child with an
        # attribute.
        continue
      trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
      if trackable is None:
        raise AssertionError("Unresolved object in checkpoint {}: {}"
                             .format(pretty_printer.node_names[node_id], node))
    if self._checkpoint.slot_restorations:
      # Sanity check; this collection should be clear if everything has been
      # restored.
      raise AssertionError("Unresolved slot restorations: %s" %
                           (self._checkpoint.slot_restorations,))
    if self._checkpoint.unused_attributes:
      unused_attribute_messages = []
      for node_id, attribute in six.iteritems(
          self._checkpoint.unused_attributes):
        obj = self._checkpoint.object_by_proto_id[node_id]
        unused_attribute_messages.append(
            "{} ({}): {}"
            .format(pretty_printer.node_names[node_id], obj, attribute))
      raise AssertionError(
          ("Unused attributes in these objects (the attributes exist in the "
           "checkpoint but were not restored):\n{}")
          .format("\n".join(unused_attribute_messages)))
    return self

  def assert_existing_objects_matched(self):
    """Asserts that trackable Python objects have been matched.

    Note that this is a weaker assertion than `assert_consumed`. It will only
    fail for existing Python objects which are (transitive) dependencies of the
    root object and which do not have an entry in the checkpoint.

    It will not fail, for example, if a `tf.keras.Layer` object has not yet been
    built and so has not created any `tf.Variable` objects.

    Returns:
      `self` for chaining.

    Raises:
      AssertionError: If a Python object exists in the transitive dependencies
        of the root object but does not have a value in the checkpoint.
    """
    for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
      trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
      if (trackable is not None and
          trackable._update_uid < self._checkpoint.restore_uid):  # pylint: disable=protected-access
        raise AssertionError("Object not assigned a value from checkpoint: %s" %
                             (node,))
    for trackable_object in self._graph_view.list_objects():
      # Remove data structures that do not contain any variables from
      # restoration checks.
      if (isinstance(trackable_object,
                     data_structures.TrackableDataStructure) and
          not trackable_object._checkpoint_dependencies):
        continue
      self._checkpoint.all_python_objects.add(trackable_object)
    unused_python_objects = (
        object_identity.ObjectIdentitySet(
            _objects_with_attributes(
                self._checkpoint.all_python_objects)) -
        object_identity.ObjectIdentitySet(
            self._checkpoint.object_by_proto_id.values()))
    if unused_python_objects:
      raise AssertionError(
          ("Some Python objects were not bound to checkpointed values, likely "
           "due to changes in the Python program: %s") %
          (list(unused_python_objects),))
    return self

  def assert_nontrivial_match(self):
    """Raises an exception if only the root object matched."""
    for trackable_object in self._graph_view.list_objects():
      self._checkpoint.all_python_objects.add(trackable_object)
    if len(self._checkpoint.object_by_proto_id) <= 1:
      unused_python_objects = (
          object_identity.ObjectIdentitySet(
              _objects_with_attributes(self._checkpoint.all_python_objects))
          - object_identity.ObjectIdentitySet(
              self._checkpoint.object_by_proto_id.values()))
      if unused_python_objects:
        raise AssertionError(
            ("Nothing except the root object matched a checkpointed value. "
             "Typically this means that the checkpoint does not match the "
             "Python program. The following objects have no matching "
             "checkpointed value: %s") % (list(unused_python_objects),))
      else:
        raise AssertionError(
            "Nothing to load. No dependencies have been added to %s yet." %
            (self._graph_view.root,))
    return self

  def run_restore_ops(self, session=None):
    """Run operations to restore objects in the dependency graph."""
    if context.executing_eagerly():
      return  # Run eagerly
    if session is None:
      session = get_session()
    session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict)

  def initialize_or_restore(self, session=None):
    """Run operations to initialize or restore objects in the dependency graph.

    Any objects in the dependency graph which have initializers but are not in
    the checkpoint will have those initializers run, unless those variables are
    being restored by a later call to `tf.train.Checkpoint.restore()`.

    This method has a sibling in `InitializationOnlyStatus` which instead
    initializes variables. That type is returned if no checkpoint is specified
    in `Saver.restore`.

    Args:
      session: The session to run init/restore ops in. If `None`, uses the
        default session.
    """
    if context.executing_eagerly():
      return  # Initialization and restoration ops are run eagerly
    if session is None:
      session = get_session()
    all_objects = self._graph_view.list_objects()
    already_initialized_objects = object_identity.ObjectIdentitySet(
        self._checkpoint.object_by_proto_id.values())
    initializers_for_non_restored_variables = [
        c.initializer for c in all_objects
        if hasattr(c, "initializer")
        and c not in already_initialized_objects
        and (getattr(c, "_update_uid", self._checkpoint.restore_uid - 1)
             < self._checkpoint.restore_uid)]
    self.run_restore_ops(session=session)
    session.run(initializers_for_non_restored_variables)

  def expect_partial(self):
    """Silence warnings about incomplete checkpoint restores."""
    self._checkpoint.expect_partial = True
    return self


class InitializationOnlyStatus(_LoadStatus):
  """Returned from `Saver.restore` when no checkpoint has been specified.

  Objects of this type have the same `assert_consumed` method as
  `CheckpointLoadStatus`, but it always fails. However,
  `initialize_or_restore` works on objects of both types, and will
  initialize variables in `InitializationOnlyStatus` objects or restore them
  otherwise.
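
  A sketch of typical use when graph building (`saver` here stands for a
  `TrackableSaver`; assumes a `session`):

  ```python
  status = saver.restore(save_path=None)  # An InitializationOnlyStatus.
  status.initialize_or_restore(session)   # Runs variable initializers.
  ```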
  """

  def __init__(self, graph_view, restore_uid):
    self._restore_uid = restore_uid
    self._graph_view = graph_view
    # Keep a reference to the root, since graph_view might only have a weakref.
    self._root = graph_view.root

  def assert_consumed(self):
    """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
    raise AssertionError(
        "No checkpoint specified (save_path=None); nothing is being restored.")

  def assert_existing_objects_matched(self):
    """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
    raise AssertionError(
        "No checkpoint specified (save_path=None); nothing is being restored.")

  def assert_nontrivial_match(self):
    """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
    raise AssertionError(
        "No checkpoint specified (save_path=None); nothing is being restored.")

  def run_restore_ops(self, session=None):
    """For consistency with `CheckpointLoadStatus`.

    Use `initialize_or_restore` for initializing if no checkpoint was passed
    to `Saver.restore` and restoring otherwise.

    Args:
      session: Not used.
    """
    raise AssertionError(
        "No checkpoint specified, so no restore ops are available "
        "(save_path=None to Saver.restore).")

  def initialize_or_restore(self, session=None):
    """Runs initialization ops for variables.

    Objects which would be saved by `Saver.save` will be initialized, unless
    those variables are being restored by a later call to
    `tf.train.Checkpoint.restore()`.

    This method does nothing when executing eagerly (initializers get run
    eagerly).

    Args:
      session: The session to run initialization ops in. If `None`, uses the
        default session.
    """
    if context.executing_eagerly():
      return  # run eagerly
    if session is None:
      session = get_session()
    trackable_objects = self._graph_view.list_objects()
    initializers = [
        c.initializer for c in trackable_objects
        if hasattr(c, "initializer") and c.initializer is not None
        and (getattr(c, "_update_uid", self._restore_uid - 1)
             < self._restore_uid)]
    session.run(initializers)


_DEPRECATED_RESTORE_INSTRUCTIONS = (
    "Restoring a name-based tf.train.Saver checkpoint using the object-based "
    "restore API. This mode uses global names to match variables, and so is "
    "somewhat fragile. It also adds new restore ops to the graph each time it "
    "is called when graph building. Prefer re-encoding training checkpoints in "
    "the object-based format: run save() on the object-based saver (the same "
    "one this message is coming from) and use that checkpoint in the future.")


class NameBasedSaverStatus(_LoadStatus):
  """Status for loading a name-based training checkpoint."""

  # Ideally this deprecation decorator would be on the class, but that
  # interferes with isinstance checks.
  @deprecation.deprecated(
      date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
  def __init__(self, checkpoint, graph_view):
    self._checkpoint = checkpoint
    self._graph_view = graph_view
    self._optionally_restored = []
    # Keep a reference to the root, since graph_view might only have a weakref.
    self._root = graph_view.root

  def add_to_optionally_restored(self, var):
    """Add a variable to the list of optionally restored variables.

    There are situations where certain variables should be ignored in assertions
    such as assert_existing_objects_matched(). One example is a checkpoint saved
    with train.Saver() and restored with train.Checkpoint(): the train.Saver()
    checkpoint may be missing the internal `save_counter` variable, which should
    be ignored on restore.

    Args:
      var: The variable to treat as optionally restored.
    """
    self._optionally_restored.append(var)

  def assert_consumed(self):
    """Raises an exception if any variables are unmatched."""
    unused_attributes = list(self._checkpoint.unused_attributes.items())
    unused_attributes = [
        a for a in unused_attributes
        if all(a[0] is not x for x in self._optionally_restored)
    ]
    if unused_attributes:
      unused_attribute_strings = [
          "\n    {}: {}".format(obj, attributes)
          for obj, attributes in unused_attributes
      ]
      raise AssertionError(
          "Some objects had attributes which were not restored:{}".format(
              "".join(unused_attribute_strings)))
    for trackable in self._graph_view.list_objects():
      # pylint: disable=protected-access
      trackable._maybe_initialize_trackable()
      if trackable._update_uid < self._checkpoint.restore_uid:
        raise AssertionError("Object not restored: %s" % (trackable,))
      # pylint: enable=protected-access
    return self

  def assert_existing_objects_matched(self):
    """Raises an exception if currently created objects are unmatched."""
    # For name-based checkpoints there's no object information in the
    # checkpoint, so there's no distinction between
    # assert_existing_objects_matched and assert_consumed (and both are less
    # useful since we don't touch Python objects or Python state).
    return self.assert_consumed()

  def assert_nontrivial_match(self):
    """Raises an exception if currently created objects are unmatched."""
    # For name-based checkpoints there's no object information in the
    # checkpoint, so there's no distinction between
    # assert_nontrivial_match and assert_consumed (and both are less
    # useful since we don't touch Python objects or Python state).
    return self.assert_consumed()

  def _gather_saveable_objects(self):
    """Walk the object graph, using global names for SaveableObjects."""
    objects = self._graph_view.list_objects()
    saveable_objects = []
    for trackable in objects:
      # pylint: disable=protected-access
      trackable._maybe_initialize_trackable()
      if trackable._update_uid < self._checkpoint.restore_uid:
        trackable._update_uid = self._checkpoint.restore_uid
      else:
        continue
      # pylint: enable=protected-access
      saveable_objects.extend(
          self._checkpoint.globally_named_object_attributes(trackable))
    return saveable_objects

  def run_restore_ops(self, session=None):
    """Load the name-based checkpoint using a new `tf.compat.v1.train.Saver`."""
    if context.executing_eagerly():
      return  # Nothing to do, variables are restored on creation.
    if session is None:
      session = get_session()
    with ops.device("/cpu:0"):
      saveables = self._gather_saveable_objects()
      v1_saver_lib.Saver(saveables).restore(
          sess=session, save_path=self._checkpoint.save_path)

  def initialize_or_restore(self, session=None):
    """Alias for `run_restore_ops`."""
    self.run_restore_ops(session=session)


class _SessionWithFeedDictAdditions(session_lib.SessionInterface):
  """Pretends to be a session, inserts extra feeds on run().
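
  A sketch of intended use (`sess`, `placeholder`, and `value` are
  illustrative names):

  ```python
  wrapped = _SessionWithFeedDictAdditions(sess, {placeholder: value})
  wrapped.run(fetches)  # {placeholder: value} is merged into the feed dict.
  ```
  """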

  def __init__(self, session, feed_additions):
    self._wrapped_session = session
    self._feed_additions = feed_additions

  def run(self, fetches, feed_dict=None, **kwargs):
    if feed_dict is None:
      feed_dict = {}
    else:
      feed_dict = feed_dict.copy()
    feed_dict.update(self._feed_additions)
    return self._wrapped_session.run(
        fetches=fetches, feed_dict=feed_dict, **kwargs)


class TrackableSaver(object):
  """Saves and restores a `Trackable` object and its dependencies.

  See `Trackable` for details of dependency management. `Saver` wraps
  `tf.compat.v1.train.Saver` for saving, including extra information about the
  graph of dependencies between Python objects. When restoring, it uses this
  information about the save-time dependency graph to more robustly match
  objects with their checkpointed values. When executing eagerly, it supports
  restoring variables on object creation (see `Saver.restore`).

  Values in a checkpoint are mapped to `Trackable` Python objects
  (`Variable`s, `Optimizer`s, `Layer`s) based on the names provided when the
  checkpoint was written. To avoid breaking existing checkpoints when modifying
  a class, dependency names (the names of attributes to which `Trackable`
  objects are assigned) may not change. These names are local to objects, in
  contrast to the `Variable.name`-based save/restore from
  `tf.compat.v1.train.Saver`, and so allow additional program transformations.
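
  A sketch of basic usage when executing eagerly (`root` stands for any
  `Trackable`; the checkpoint prefix is illustrative):

  ```python
  saver = TrackableSaver(graph_view_lib.ObjectGraphView(root))
  path = saver.save("/tmp/ckpt/prefix")
  saver.restore(path).assert_consumed()
  ```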
  """

  def __init__(self, graph_view):
    """Configure saving.

    Args:
      graph_view: A `GraphView` object containing a description of the object
        graph to save.
    """
    # The file prefix placeholder is created lazily when graph building (and not
    # at all when executing eagerly) to avoid creating ops in the constructor
    # (when they may never be necessary).
    self._file_prefix_placeholder = None

    # Op caching for save
    self._object_graph_feed_tensor = None
    self._last_save_object_graph = None
    self._file_prefix_feed_tensor = None
    self._cached_save_operation = None

    # Op caching for restore, shared between _CheckpointRestoreCoordinators
    self._restore_op_cache = {}
    self._graph_view = graph_view

  def _gather_saveables(self, object_graph_tensor=None):
    """Wraps _serialize_object_graph to include the object graph proto."""
    (named_saveable_objects, graph_proto,
     feed_additions) = self._graph_view.serialize_object_graph()
    if object_graph_tensor is None:
      with ops.device("/cpu:0"):
        object_graph_tensor = constant_op.constant(
            graph_proto.SerializeToString(), dtype=dtypes.string)
    else:
      feed_additions.update(
          {object_graph_tensor: graph_proto.SerializeToString()})
    assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
    named_saveable_objects.append(
        base.NoRestoreSaveable(
            tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY))
    return named_saveable_objects, graph_proto, feed_additions

  def _save_cached_when_graph_building(self,
                                       file_prefix,
                                       object_graph_tensor=None):
    """Create or retrieve save ops.

    Args:
      file_prefix: The prefix for saved checkpoint files.
      object_graph_tensor: A `Tensor` to which the current object graph will be
        fed.

    Returns:
      A two-element tuple with a filename tensor and a feed_dict of tensors to
      feed when running it (if graph building). The feed dict contains the
      current object graph and any Python state to be saved in the
      checkpoint. When executing eagerly, only the first element is meaningful.
    """
    (named_saveable_objects, graph_proto,
     feed_additions) = self._gather_saveables(
         object_graph_tensor=object_graph_tensor)
    if (self._last_save_object_graph != graph_proto
        # When executing eagerly, we need to re-create SaveableObjects each time
        # save() is called so they pick up new Tensors passed to their
        # constructors. That means the Saver needs to be copied with a new
        # var_list.
        or context.executing_eagerly() or ops.inside_function()):
      saver = functional_saver.MultiDeviceSaver(named_saveable_objects)
      save_op = saver.save(file_prefix)
      with ops.device("/cpu:0"):
        with ops.control_dependencies([save_op]):
          self._cached_save_operation = array_ops.identity(file_prefix)
      self._last_save_object_graph = graph_proto
    return self._cached_save_operation, feed_additions

  def save(self, file_prefix, checkpoint_number=None, session=None):
    """Save a training checkpoint.

    The saved checkpoint includes variables created by this object and any
    Trackable objects it depends on at the time `Saver.save()` is called.
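
    A sketch of typical use (the prefix and number are illustrative):

    ```python
    path = saver.save("/tmp/ckpt/prefix", checkpoint_number=1)
    # path now names the checkpoint, e.g. "/tmp/ckpt/prefix-1".
    ```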

    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix). Names are generated based on this
        prefix and `checkpoint_number`, if provided.
      checkpoint_number: An integer variable or Tensor, used to number
        checkpoints. Typically this value is saved along with other variables in
        training checkpoints, which will happen automatically if it was created
        by `root_trackable` or one of its dependencies (via
        `Trackable._add_variable`).
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.

    Returns:
      The full path to the checkpoint.
    """
    feed_dict = {}
    use_session = (not context.executing_eagerly() and
                   not ops.inside_function())
    if checkpoint_number:
      file_prefix = "%s-%d" % (file_prefix, checkpoint_number)
    if use_session:
      if self._object_graph_feed_tensor is None:
        with ops.device("/cpu:0"):
          self._object_graph_feed_tensor = constant_op.constant(
              "", dtype=dtypes.string)
          self._file_prefix_feed_tensor = constant_op.constant(
              "", dtype=dtypes.string)
      object_graph_tensor = self._object_graph_feed_tensor
      file_prefix_tensor = self._file_prefix_feed_tensor
      feed_dict[file_prefix_tensor] = file_prefix
    else:
      with ops.device("/cpu:0"):
        file_prefix_tensor = constant_op.constant(
            file_prefix, dtype=dtypes.string)
      object_graph_tensor = None

    file_io.recursive_create_dir(os.path.dirname(file_prefix))
    save_path, new_feed_additions = self._save_cached_when_graph_building(
        file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor)
    if new_feed_additions:
      feed_dict.update(new_feed_additions)
    if not use_session:
      session = None
    elif session is None:
      session = get_session()

    if session:
      return session.run(save_path, feed_dict=feed_dict)
    else:
      return save_path

  def restore(self, save_path):
    """Restore a training checkpoint.

    Restores `root_trackable` and any objects that it tracks
    (transitive). Either assigns values immediately if variables to restore have
    been created already, or defers restoration until the variables are
    created. Dependencies added to the `root_trackable` passed to the
    constructor after this call will be matched if they have a corresponding
    object in the checkpoint.

    When building a graph, restorations are added to the graph but not run.

    To disallow deferred loading, assert immediately that all checkpointed
    variables have been matched to variable objects:

    ```python
    saver = Saver(root)
    saver.restore(path).assert_consumed()
    ```

    An exception will be raised unless every object was matched and its
    variables already exist.

    When graph building, `assert_consumed()` indicates that all of the restore
    ops which will be created for this checkpoint have been created. They can be
    run via the `run_restore_ops()` function of the status object:

    ```python
    saver.restore(path).assert_consumed().run_restore_ops()
    ```

    If the checkpoint has not been consumed completely, then the list of restore
    ops will grow as more objects are added to the dependency graph.

    Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
    method. There is no deferred loading, and names are used to match
    variables. No restore ops are created/run until `run_restore_ops()` or
    `initialize_or_restore()` are called on the returned status object, even
    when executing eagerly. Re-encode name-based checkpoints using this
    object-based `Saver.save` as soon as possible.

    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If None (as when there is no latest
        checkpoint for `tf.train.latest_checkpoint` to return), returns an
        object which may run initializers for objects in the dependency graph.
        If the checkpoint was written by the name-based
        `tf.compat.v1.train.Saver`, names are used to match variables.

    Returns:
      A load status object, which can be used to make assertions about the
      status of checkpoint restoration and run initialization/restore ops
      (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if
      `save_path` is `None`).

      If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus`
      object is returned which runs restore ops from a name-based saver.
    """
    if save_path is None:
      return InitializationOnlyStatus(self._graph_view, ops.uid())
    reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    graph_building = not context.executing_eagerly()
    if graph_building:
      dtype_map = None
    else:
      dtype_map = reader.get_variable_to_dtype_map()
    try:
      object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError:
      # The object graph proto does not exist in this checkpoint. Try the
      # name-based compatibility mode.
      restore_coordinator = _NameBasedRestoreCoordinator(
          save_path=save_path,
          dtype_map=dtype_map)
      if not graph_building:
        for existing_trackable in self._graph_view.list_objects():
          # pylint: disable=protected-access
          existing_trackable._maybe_initialize_trackable()
          existing_trackable._name_based_restores.add(restore_coordinator)
          existing_trackable._name_based_attribute_restore(restore_coordinator)
          # pylint: enable=protected-access
      return NameBasedSaverStatus(
          restore_coordinator,
          graph_view=self._graph_view)

    if graph_building:
      if self._file_prefix_placeholder is None:
        with ops.device("/cpu:0"):
          self._file_prefix_placeholder = constant_op.constant("model")
      file_prefix_tensor = self._file_prefix_placeholder
      file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
    else:
      with ops.device("/cpu:0"):
        file_prefix_tensor = constant_op.constant(save_path)
      file_prefix_feed_dict = None
    object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
    object_graph_proto.ParseFromString(object_graph_string)
    checkpoint = _CheckpointRestoreCoordinator(
        object_graph_proto=object_graph_proto,
        save_path=save_path,
        save_path_tensor=file_prefix_tensor,
        restore_op_cache=self._restore_op_cache,
        graph_view=self._graph_view)
    base.CheckpointPosition(
        checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)
    load_status = CheckpointLoadStatus(
        checkpoint,
        graph_view=self._graph_view,
        feed_dict=file_prefix_feed_dict)
    return load_status


1312def frozen_saver(root_trackable):
1313  """Creates a static `tf.compat.v1.train.Saver` from a trackable object.
1314
1315  The returned `Saver` saves object-based checkpoints, but these checkpoints
1316  will no longer reflect structural changes to the object graph, only changes to
1317  the values of `Variable`s added as dependencies of the root object before
1318  `freeze` was called.
1319
1320  `restore` works on the returned `Saver`, but requires that the object graph of
1321  the checkpoint being loaded exactly matches the object graph when `freeze` was
1322  called. This is in contrast the object-based restore performed by
1323  `tf.train.Checkpoint` which attempts a fuzzy matching between a checkpoint's
1324  object graph and the current Python object graph.
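
  A minimal usage sketch (assuming `root` is any trackable object; the
  `MultiDeviceSaver.save` call shown here is assumed to accept a checkpoint
  prefix, which may differ across TensorFlow versions):

  ```python
  root = tf.train.Checkpoint(v=tf.Variable(1.))
  saver = frozen_saver(root)
  saver.save("/tmp/frozen_ckpt")  # Saves the object graph frozen here.
  root.extra = tf.Variable(2.)  # Not reflected in `saver`'s checkpoints.
  ```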

  Args:
    root_trackable: A trackable object to save.

  Returns:
    A saver which saves object-based checkpoints for the object graph frozen at
    the time `frozen_saver` was called.
  """
  named_saveable_objects = graph_view_lib.ObjectGraphView(
      root_trackable).frozen_saveable_objects()
  return functional_saver.MultiDeviceSaver(named_saveable_objects)


def saver_with_op_caching(obj):
  """A TrackableSaver with a SaveableObject cache when graph building."""
  if context.executing_eagerly():
    saveables_cache = None
  else:
    saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
  return TrackableSaver(
      graph_view_lib.ObjectGraphView(
          weakref.ref(obj), saveables_cache=saveables_cache))


# Mentions graph building / Sessions. The v2 version is below.
@tf_export(v1=["train.Checkpoint"])
class CheckpointV1(tracking.AutoTrackable):
  """Groups trackable objects, saving and restoring them.

  `Checkpoint`'s constructor accepts keyword arguments whose values are types
  that contain trackable state, such as `tf.compat.v1.train.Optimizer`
  implementations, `tf.Variable`, `tf.keras.layers.Layer` implementations, or
  `tf.keras.Model` implementations. It saves these values with a checkpoint,
  and maintains a `save_counter` for numbering checkpoints.

  Example usage when graph building:

  ```python
  import tensorflow as tf
  import os

  checkpoint_directory = "/tmp/training_checkpoints"
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
  train_op = optimizer.minimize( ... )
  status.assert_consumed()  # Optional sanity checks.
  with tf.compat.v1.Session() as session:
    # Use the Session to restore variables, or initialize them if
    # tf.train.latest_checkpoint returned None.
    status.initialize_or_restore(session)
    for _ in range(num_training_steps):
      session.run(train_op)
    checkpoint.save(file_prefix=checkpoint_prefix)
  ```

  Example usage with eager execution enabled:

  ```python
  import tensorflow as tf
  import os

  tf.compat.v1.enable_eager_execution()

  checkpoint_directory = "/tmp/training_checkpoints"
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
  for _ in range(num_training_steps):
    optimizer.minimize( ... )  # Variables will be restored on creation.
  status.assert_consumed()  # Optional sanity checks.
  checkpoint.save(file_prefix=checkpoint_prefix)
  ```

  `Checkpoint.save` and `Checkpoint.restore` write and read object-based
  checkpoints, in contrast to `tf.compat.v1.train.Saver` which writes and reads
  `variable.name` based checkpoints. Object-based checkpointing saves a graph of
  dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s,
  etc.) with named edges, and this graph is used to match variables when
  restoring a checkpoint. It can be more robust to changes in the Python
  program, and helps to support restore-on-create for variables when executing
  eagerly. Prefer `tf.train.Checkpoint` over `tf.compat.v1.train.Saver` for new
  code.

  `Checkpoint` objects have dependencies on the objects passed as keyword
  arguments to their constructors, and each dependency is given a name that is
  identical to the name of the keyword argument for which it was created.
  TensorFlow classes like `Layer`s and `Optimizer`s will automatically add
  dependencies on their variables (e.g. "kernel" and "bias" for
  `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing
  dependencies easy in user-defined classes, since `Model` hooks into attribute
  assignment. For example:

  ```python
  class Regress(tf.keras.Model):

    def __init__(self):
      super(Regress, self).__init__()
      self.input_transform = tf.keras.layers.Dense(10)
      # ...

    def call(self, inputs):
      x = self.input_transform(inputs)
      # ...
  ```

  This `Model` has a dependency named "input_transform" on its `Dense` layer,
  which in turn depends on its variables. As a result, saving an instance of
  `Regress` using `tf.train.Checkpoint` will also save all the variables created
  by the `Dense` layer.

  When variables are assigned to multiple workers, each worker writes its own
  section of the checkpoint. These sections are then merged/re-indexed to behave
  as a single checkpoint. This avoids copying all variables to one worker, but
  does require that all workers see a common filesystem.

  While `tf.keras.Model.save_weights` and `tf.train.Checkpoint.save` save in the
  same format, note that the root of the resulting checkpoint is the object the
  save method is attached to. This means saving a `tf.keras.Model` using
  `save_weights` and loading into a `tf.train.Checkpoint` with a `Model`
  attached (or vice versa) will not match the `Model`'s variables. See the
  [guide to training checkpoints](https://www.tensorflow.org/guide/checkpoint)
  for details. Prefer `tf.train.Checkpoint` over `tf.keras.Model.save_weights`
  for training checkpoints.

  Attributes:
    save_counter: Incremented when `save()` is called. Used to number
      checkpoints.
  """

  def __init__(self, **kwargs):
    """Group objects into a training checkpoint.

    Args:
      **kwargs: Keyword arguments are set as attributes of this object, and are
        saved with the checkpoint. Values must be trackable objects.

    Raises:
      ValueError: If objects in `kwargs` are not trackable.
    """
    super(CheckpointV1, self).__init__()
    for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
      setattr(self, k, v)
      if not isinstance(
          getattr(self, k), (base.Trackable, def_function.Function)):
        raise ValueError(
            ("`Checkpoint` was expecting a trackable object (an object "
             "derived from `TrackableBase`), got %s. If you believe this "
             "object should be trackable (i.e. it is part of the "
             "TensorFlow Python API and manages state), please open an issue.")
            % (v,))
    self._save_counter = None  # Created lazily for restore-on-create.
    self._save_assign_op = None
    self._saver = saver_with_op_caching(self)

  def _maybe_create_save_counter(self):
    """Create a save counter if it does not yet exist."""
    if self._save_counter is None:
      # Initialized to 0 and incremented before saving.
      with ops.device("/cpu:0"):
        # add_variable creates a dependency named "save_counter"; NoDependency
        # prevents creating a second dependency named "_save_counter".
        self._save_counter = data_structures.NoDependency(
            add_variable(
                self,
                name="save_counter",
                initializer=0,
                dtype=dtypes.int64,
                trainable=False))

  def write(self, file_prefix, session=None):
    """Writes a training checkpoint.

    The checkpoint includes variables created by this object and any
    trackable objects it depends on at the time `Checkpoint.write()` is
    called.

    `write` does not number checkpoints, increment `save_counter`, or update the
    metadata used by `tf.train.latest_checkpoint`. It is primarily intended for
    use by higher level checkpoint management utilities. `save` provides a very
    basic implementation of these features.
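
    A graph-building sketch (hypothetical prefix; `session` is the
    `tf.compat.v1.Session` that owns the variable values):

    ```python
    path = checkpoint.write("/tmp/ckpt/manual", session=session)
    # `write` does not update the metadata read by tf.train.latest_checkpoint,
    # so keep `path` and restore from it explicitly.
    status = checkpoint.restore(path)
    status.run_restore_ops(session)
    ```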

    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix).
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.

    Returns:
      The full path to the checkpoint (i.e. `file_prefix`).
    """
    output = self._saver.save(file_prefix=file_prefix, session=session)
    if tensor_util.is_tensor(output):
      if context.executing_eagerly():
        return compat.as_str(output.numpy())
      else:
        # Function building
        return output
    else:
      # Graph + Session, so we already ran it via session.run.
      return compat.as_str(output)

  @property
  def save_counter(self):
    """An integer variable which starts at zero and is incremented on save.

    Used to number checkpoints.

    Returns:
      The save counter variable.
    """
    self._maybe_create_save_counter()
    return self._save_counter

  def save(self, file_prefix, session=None):
    """Saves a training checkpoint and provides basic checkpoint management.

    The saved checkpoint includes variables created by this object and any
    trackable objects it depends on at the time `Checkpoint.save()` is
    called.

    `save` is a basic convenience wrapper around the `write` method,
    sequentially numbering checkpoints using `save_counter` and updating the
    metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
    management, for example garbage collection and custom numbering, may be
    provided by other utilities which also wrap `write`
    (`tf.train.CheckpointManager` for example).

    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix). Names are generated based on this
        prefix and `Checkpoint.save_counter`.
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.

    Returns:
      The full path to the checkpoint.
    """
    graph_building = not context.executing_eagerly()
    if graph_building:
      if ops.inside_function():
        raise NotImplementedError(
            "Calling tf.train.Checkpoint.save() from a function is not "
            "supported, as save() modifies saving metadata in ways not "
            "supported by TensorFlow Operations. Consider using "
            "tf.train.Checkpoint.write(), a lower-level API which does not "
            "update metadata. tf.train.latest_checkpoint and related APIs will "
            "not see this checkpoint.")
      if session is None:
        session = get_session()
      if self._save_counter is None:
        # When graph building, if this is a new save counter variable then it
        # needs to be initialized before assign_add. This is only an issue if
        # restore() has not been called first.
        session.run(self.save_counter.initializer)
    if not graph_building or self._save_assign_op is None:
      with ops.colocate_with(self.save_counter):
        assign_op = self.save_counter.assign_add(1, read_value=True)
      if graph_building:
        self._save_assign_op = data_structures.NoDependency(assign_op)
    if graph_building:
      checkpoint_number = session.run(self._save_assign_op)
    else:
      checkpoint_number = assign_op.numpy()
    file_path = self.write(
        "%s-%d" % (file_prefix, checkpoint_number), session=session)
    checkpoint_management.update_checkpoint_state_internal(
        save_dir=os.path.dirname(file_prefix),
        model_checkpoint_path=file_path,
        all_model_checkpoint_paths=[file_path],
        save_relative_paths=True)
    return file_path

  def restore(self, save_path):
    """Restore a training checkpoint.

    Restores this `Checkpoint` and any objects it depends on.

    When executing eagerly, either assigns values immediately if variables to
    restore have been created already, or defers restoration until the variables
    are created. Dependencies added after this call will be matched if they have
    a corresponding object in the checkpoint (the restore request will queue in
    any trackable object waiting for the expected dependency to be added).

    When graph building, restoration ops are added to the graph but not run
    immediately.

    To ensure that loading is complete and no more assignments will take place,
    use the `assert_consumed()` method of the status object returned by
    `restore`:

    ```python
    checkpoint = tf.train.Checkpoint( ... )
    checkpoint.restore(path).assert_consumed()
    ```

    An exception will be raised if any Python objects in the dependency graph
    were not found in the checkpoint, or if any checkpointed values do not have
    a matching Python object.

    When graph building, `assert_consumed()` indicates that all of the restore
    ops that will be created for this checkpoint have been created. They can be
    run via the `run_restore_ops()` method of the status object:

    ```python
    checkpoint.restore(path).assert_consumed().run_restore_ops()
    ```

    If the checkpoint has not been consumed completely, then the list of restore
    ops will grow as more objects are added to the dependency graph.

    Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
    method. Names are used to match variables. No restore ops are created/run
    until `run_restore_ops()` or `initialize_or_restore()` is called on the
    returned status object when graph building, but there is restore-on-creation
    when executing eagerly. Re-encode name-based checkpoints using
    `tf.train.Checkpoint.save` as soon as possible.

    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If None (as when there is no latest
        checkpoint for `tf.train.latest_checkpoint` to return), returns an
        object which may run initializers for objects in the dependency graph.
        If the checkpoint was written by the name-based
        `tf.compat.v1.train.Saver`, names are used to match variables.

    Returns:
      A load status object, which can be used to make assertions about the
      status of a checkpoint restoration and run initialization/restore ops.

      The returned status object has the following methods:

      * `assert_consumed()`:
          Raises an exception if any variables are unmatched: either
          checkpointed values which don't have a matching Python object or
          Python objects in the dependency graph with no values in the
          checkpoint. This method returns the status object, and so may be
          chained with `initialize_or_restore` or `run_restore_ops`.

      * `assert_existing_objects_matched()`:
          Raises an exception if any existing Python objects in the dependency
          graph are unmatched. Unlike `assert_consumed`, this assertion will
          pass if values in the checkpoint have no corresponding Python
          objects. For example, a `tf.keras.layers.Layer` object which has not
          yet been built, and so has not created any variables, will pass this
          assertion but fail `assert_consumed`. Useful when loading part of a
          larger checkpoint into a new Python program, e.g. when a training
          checkpoint with a `tf.compat.v1.train.Optimizer` was saved but only
          the state required for inference is being loaded. This method returns
          the status object, and so may be chained with `initialize_or_restore`
          or `run_restore_ops`.

      * `assert_nontrivial_match()`: Asserts that something aside from the root
          object was matched. This is a very weak assertion, but is useful for
          sanity checking in library code where objects may exist in the
          checkpoint which haven't been created in Python and some Python
          objects may not have a checkpointed value.

      * `expect_partial()`: Silence warnings about incomplete checkpoint
          restores. Warnings are otherwise printed for unused parts of the
          checkpoint file or object when the `Checkpoint` object is deleted
          (often at program shutdown).

      * `initialize_or_restore(session=None)`:
          When graph building, runs variable initializers if `save_path` is
          `None`, but otherwise runs restore operations. If no `session` is
          explicitly specified, the default session is used. No effect when
          executing eagerly (variables are initialized or restored eagerly).
          See the sketch after this list.

      * `run_restore_ops(session=None)`:
          When graph building, runs restore operations. If no `session` is
          explicitly specified, the default session is used. No effect when
          executing eagerly (restore operations are run eagerly). May only be
          called when `save_path` is not `None`.
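
      For example, when graph building (a sketch; `directory` and `train_op`
      are hypothetical):

      ```python
      status = checkpoint.restore(tf.train.latest_checkpoint(directory))
      with tf.compat.v1.Session() as session:
        # Restores variables if a checkpoint was found, otherwise initializes.
        status.initialize_or_restore(session)
        session.run(train_op)
      ```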
    """
    status = self._saver.restore(save_path=save_path)
    # Create the save counter now so it gets initialized with other variables
    # when graph building. Creating it earlier would lead to errors when using,
    # say, train.Saver() to save the model before initializing it.
    self._maybe_create_save_counter()
    if isinstance(status, NameBasedSaverStatus):
      status.add_to_optionally_restored(self.save_counter)
    return status


@tf_export("train.Checkpoint", v1=[])
class Checkpoint(tracking.AutoTrackable):
  """Groups trackable objects, saving and restoring them.

  `Checkpoint`'s constructor accepts keyword arguments whose values are types
  that contain trackable state, such as `tf.keras.optimizers.Optimizer`
  implementations, `tf.Variable`, `tf.keras.layers.Layer` implementations, or
  `tf.keras.Model` implementations. It saves these values with a checkpoint,
  and maintains a `save_counter` for numbering checkpoints.

  Example usage:

  ```python
  import tensorflow as tf
  import os

  checkpoint_directory = "/tmp/training_checkpoints"
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
  for _ in range(num_training_steps):
    optimizer.minimize( ... )  # Variables will be restored on creation.
  status.assert_consumed()  # Optional sanity checks.
  checkpoint.save(file_prefix=checkpoint_prefix)
  ```

  `Checkpoint.save` and `Checkpoint.restore` write and read object-based
  checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver`,
  which writes and reads `variable.name` based checkpoints. Object-based
  checkpointing saves a graph of dependencies between Python objects
  (`Layer`s, `Optimizer`s, `Variable`s, etc.) with named edges, and this graph
  is used to match variables when restoring a checkpoint. It can be more
  robust to changes in the Python program, and helps to support
  restore-on-create for variables.

  `Checkpoint` objects have dependencies on the objects passed as keyword
  arguments to their constructors, and each dependency is given a name that is
  identical to the name of the keyword argument for which it was created.
  TensorFlow classes like `Layer`s and `Optimizer`s will automatically add
  dependencies on their variables (e.g. "kernel" and "bias" for
  `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing
  dependencies easy in user-defined classes, since `Model` hooks into attribute
  assignment. For example:

  ```python
  class Regress(tf.keras.Model):

    def __init__(self):
      super(Regress, self).__init__()
      self.input_transform = tf.keras.layers.Dense(10)
      # ...

    def call(self, inputs):
      x = self.input_transform(inputs)
      # ...
  ```

  This `Model` has a dependency named "input_transform" on its `Dense` layer,
  which in turn depends on its variables. As a result, saving an instance of
  `Regress` using `tf.train.Checkpoint` will also save all the variables created
  by the `Dense` layer.

  When variables are assigned to multiple workers, each worker writes its own
  section of the checkpoint. These sections are then merged/re-indexed to behave
  as a single checkpoint. This avoids copying all variables to one worker, but
  does require that all workers see a common filesystem.

  While `tf.keras.Model.save_weights` and `tf.train.Checkpoint.save` save in the
  same format, note that the root of the resulting checkpoint is the object the
  save method is attached to. This means saving a `tf.keras.Model` using
  `save_weights` and loading into a `tf.train.Checkpoint` with a `Model`
  attached (or vice versa) will not match the `Model`'s variables. See the
  [guide to training checkpoints](https://www.tensorflow.org/guide/checkpoint)
  for details. Prefer `tf.train.Checkpoint` over `tf.keras.Model.save_weights`
  for training checkpoints.
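
  A sketch of the root-object mismatch (hypothetical `model`; on the
  `Checkpoint` side the weights live under an extra "model" edge):

  ```python
  ckpt = tf.train.Checkpoint(model=model)
  ckpt.save("/tmp/from_checkpoint")  # Root of the checkpoint: `ckpt`.
  model.save_weights("/tmp/from_model")  # Root of the checkpoint: `model`.
  # Restoring "/tmp/from_model" into `ckpt` will not match `model`'s
  # variables; load it with `model.load_weights` instead.
  ```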

  Attributes:
    save_counter: Incremented when `save()` is called. Used to number
      checkpoints.
  """

  def __init__(self, **kwargs):
    """Group objects into a training checkpoint.

    Args:
      **kwargs: Keyword arguments are set as attributes of this object, and are
        saved with the checkpoint. Values must be trackable objects.

    Raises:
      ValueError: If objects in `kwargs` are not trackable.
    """
    super(Checkpoint, self).__init__()
    for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
      setattr(self, k, v)
      if not isinstance(
          getattr(self, k), (base.Trackable, def_function.Function)):
        raise ValueError(
            ("`Checkpoint` was expecting a trackable object (an object "
             "derived from `TrackableBase`), got %s. If you believe this "
             "object should be trackable (i.e. it is part of the "
             "TensorFlow Python API and manages state), please open an issue.")
            % (v,))
    self._save_counter = None  # Created lazily for restore-on-create.
    self._save_assign_op = None
    self._saver = saver_with_op_caching(self)

  def _maybe_create_save_counter(self):
    """Create a save counter if it does not yet exist."""
    if self._save_counter is None:
      # Initialized to 0 and incremented before saving.
      with ops.device("/cpu:0"):
        # add_variable creates a dependency named "save_counter"; NoDependency
        # prevents creating a second dependency named "_save_counter".
        self._save_counter = data_structures.NoDependency(
            add_variable(
                self,
                name="save_counter",
                initializer=0,
                dtype=dtypes.int64,
                trainable=False))

  def write(self, file_prefix):
    """Writes a training checkpoint.

    The checkpoint includes variables created by this object and any
    trackable objects it depends on at the time `Checkpoint.write()` is
    called.

    `write` does not number checkpoints, increment `save_counter`, or update the
    metadata used by `tf.train.latest_checkpoint`. It is primarily intended for
    use by higher level checkpoint management utilities. `save` provides a very
    basic implementation of these features.
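
    For example (a sketch; the prefix is hypothetical and
    `tf.train.latest_checkpoint` will not see it):

    ```python
    path = checkpoint.write("/tmp/ckpt/manual")
    checkpoint.restore(path)  # Restore by the exact prefix `write` returned.
    ```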

    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix).

    Returns:
      The full path to the checkpoint (i.e. `file_prefix`).
    """
    output = self._saver.save(file_prefix=file_prefix)
    if tensor_util.is_tensor(output):
      if context.executing_eagerly():
        return compat.as_str(output.numpy())
      else:
        # Function building
        return output
    else:
      # Graph + Session, so we already ran it via session.run.
      return compat.as_str(output)

  @property
  def save_counter(self):
    """An integer variable which starts at zero and is incremented on save.

    Used to number checkpoints.

    Returns:
      The save counter variable.
    """
    self._maybe_create_save_counter()
    return self._save_counter

  def save(self, file_prefix):
    """Saves a training checkpoint and provides basic checkpoint management.

    The saved checkpoint includes variables created by this object and any
    trackable objects it depends on at the time `Checkpoint.save()` is
    called.

    `save` is a basic convenience wrapper around the `write` method,
    sequentially numbering checkpoints using `save_counter` and updating the
    metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
    management, for example garbage collection and custom numbering, may be
    provided by other utilities which also wrap `write`
    (`tf.train.CheckpointManager` for example).
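
    A sketch of the numbering behavior (hypothetical directory; the counter
    values depend on earlier saves):

    ```python
    path_1 = checkpoint.save("/tmp/ckpt/prefix")  # e.g. "/tmp/ckpt/prefix-1"
    path_2 = checkpoint.save("/tmp/ckpt/prefix")  # e.g. "/tmp/ckpt/prefix-2"
    tf.train.latest_checkpoint("/tmp/ckpt")  # Typically equal to path_2.
    ```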

    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix). Names are generated based on this
        prefix and `Checkpoint.save_counter`.

    Returns:
      The full path to the checkpoint.
    """
    graph_building = not context.executing_eagerly()
    if graph_building:
      if ops.inside_function():
        raise NotImplementedError(
            "Calling tf.train.Checkpoint.save() from a function is not "
            "supported, as save() modifies saving metadata in ways not "
            "supported by TensorFlow Operations. Consider using "
            "tf.train.Checkpoint.write(), a lower-level API which does not "
            "update metadata. tf.train.latest_checkpoint and related APIs will "
            "not see this checkpoint.")
      session = get_session()
      if self._save_counter is None:
        # When graph building, if this is a new save counter variable then it
        # needs to be initialized before assign_add. This is only an issue if
        # restore() has not been called first.
        session.run(self.save_counter.initializer)
    if not graph_building or self._save_assign_op is None:
      with ops.colocate_with(self.save_counter):
        assign_op = self.save_counter.assign_add(1, read_value=True)
      if graph_building:
        self._save_assign_op = data_structures.NoDependency(assign_op)
    if graph_building:
      checkpoint_number = session.run(self._save_assign_op)
    else:
      checkpoint_number = assign_op.numpy()
    file_path = self.write("%s-%d" % (file_prefix, checkpoint_number))
    checkpoint_management.update_checkpoint_state_internal(
        save_dir=os.path.dirname(file_prefix),
        model_checkpoint_path=file_path,
        all_model_checkpoint_paths=[file_path],
        save_relative_paths=True)
    return file_path

  def restore(self, save_path):
    """Restore a training checkpoint.

    Restores this `Checkpoint` and any objects it depends on.

    Either assigns values immediately if variables to restore have been created
    already, or defers restoration until the variables are created. Dependencies
    added after this call will be matched if they have a corresponding object in
    the checkpoint (the restore request will queue in any trackable object
    waiting for the expected dependency to be added).

    To ensure that loading is complete and no more assignments will take place,
    use the `assert_consumed()` method of the status object returned by
    `restore`:

    ```python
    checkpoint = tf.train.Checkpoint( ... )
    checkpoint.restore(path).assert_consumed()
    ```

    An exception will be raised if any Python objects in the dependency graph
    were not found in the checkpoint, or if any checkpointed values do not have
    a matching Python object.

    Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can
    be loaded using this method. Names are used to match variables. Re-encode
    name-based checkpoints using `tf.train.Checkpoint.save` as soon as
    possible.

    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If None (as when there is no latest
        checkpoint for `tf.train.latest_checkpoint` to return), returns an
        object which may run initializers for objects in the dependency graph.
        If the checkpoint was written by the name-based
        `tf.compat.v1.train.Saver`, names are used to match variables.

    Returns:
      A load status object, which can be used to make assertions about the
      status of a checkpoint restoration.

      The returned status object has the following methods:

      * `assert_consumed()`:
          Raises an exception if any variables are unmatched: either
          checkpointed values which don't have a matching Python object or
          Python objects in the dependency graph with no values in the
          checkpoint. This method returns the status object, and so may be
          chained with other assertions.

      * `assert_existing_objects_matched()`:
          Raises an exception if any existing Python objects in the dependency
          graph are unmatched. Unlike `assert_consumed`, this assertion will
          pass if values in the checkpoint have no corresponding Python
          objects. For example, a `tf.keras.layers.Layer` object which has not
          yet been built, and so has not created any variables, will pass this
          assertion but fail `assert_consumed`. Useful when loading part of a
          larger checkpoint into a new Python program, e.g. when a training
          checkpoint with a `tf.compat.v1.train.Optimizer` was saved but only
          the state required for inference is being loaded. This method returns
          the status object, and so may be chained with other assertions.

      * `assert_nontrivial_match()`: Asserts that something aside from the root
          object was matched. This is a very weak assertion, but is useful for
          sanity checking in library code where objects may exist in the
          checkpoint which haven't been created in Python and some Python
          objects may not have a checkpointed value.

      * `expect_partial()`: Silence warnings about incomplete checkpoint
          restores. Warnings are otherwise printed for unused parts of the
          checkpoint file or object when the `Checkpoint` object is deleted
          (often at program shutdown). See the sketch after this list.
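
      For example, when loading only part of a training checkpoint for
      inference (a sketch; `model` and `path` are hypothetical):

      ```python
      status = tf.train.Checkpoint(model=model).restore(path)
      status.assert_existing_objects_matched()  # Optimizer state is unused.
      status.expect_partial()  # Silence unused-value warnings at shutdown.
      ```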
    """
    status = self._saver.restore(save_path=save_path)
    # Create the save counter now so it gets initialized with other variables
    # when graph building. Creating it earlier would lead to errors when using,
    # say, train.Saver() to save the model before initializing it.
    self._maybe_create_save_counter()
    if isinstance(status, NameBasedSaverStatus):
      status.add_to_optionally_restored(self.save_counter)
    return status
