• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Utilities for saving/loading Trackable objects."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import collections
22import functools
23import os
24import weakref
25
26import six
27
28from tensorflow.core.protobuf import trackable_object_graph_pb2
29from tensorflow.python.client import session as session_lib
30from tensorflow.python.eager import context
31from tensorflow.python.eager import def_function
32from tensorflow.python.framework import constant_op
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import errors_impl
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import tensor_shape
37from tensorflow.python.framework import tensor_util
38from tensorflow.python.lib.io import file_io
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import gen_io_ops as io_ops
41from tensorflow.python.ops import init_ops
42from tensorflow.python.ops import variable_scope
43from tensorflow.python.ops import variables
44from tensorflow.python.platform import gfile
45from tensorflow.python.platform import tf_logging as logging
46from tensorflow.python.saved_model import utils_impl
47from tensorflow.python.training import checkpoint_management
48from tensorflow.python.training import py_checkpoint_reader
49from tensorflow.python.training import saver as v1_saver_lib
50from tensorflow.python.training.saving import checkpoint_options
51from tensorflow.python.training.saving import functional_saver
52from tensorflow.python.training.saving import saveable_object_util
53from tensorflow.python.training.tracking import base
54from tensorflow.python.training.tracking import data_structures
55from tensorflow.python.training.tracking import graph_view as graph_view_lib
56from tensorflow.python.training.tracking import tracking
57from tensorflow.python.util import compat
58from tensorflow.python.util import deprecation
59from tensorflow.python.util import object_identity
60from tensorflow.python.util import tf_contextlib
61from tensorflow.python.util import tf_inspect
62from tensorflow.python.util.tf_export import tf_export
63
64
# The callable that provides the Keras default session, which is needed for
# saving. Set through `register_session_provider`; `None` until then.
_SESSION_PROVIDER = None
67
68
def register_session_provider(session_provider):
  """Registers a fallback callable used by `get_session`.

  Only the first registration takes effect: once a provider has been set,
  later calls leave it unchanged.

  Args:
    session_provider: A callable taking no arguments and returning a session.
      It is invoked by `get_session` when there is no default session.
  """
  global _SESSION_PROVIDER
  if _SESSION_PROVIDER is not None:
    # A provider is already registered; the first one wins.
    return
  _SESSION_PROVIDER = session_provider
73
74
def get_session():
  """Returns a session, preferring TF's default session.

  TF's default session is preferred because `get_session` from Keras has
  side effects. If there is no default session, the provider registered with
  `register_session_provider` (if any) is called to obtain one.

  Returns:
    A session object, or `None` when there is no default session and no
    provider has been registered.
  """
  session = ops.get_default_session()
  # `_SESSION_PROVIDER` is only read here, so no `global` declaration is
  # needed.
  if session is None and _SESSION_PROVIDER is not None:
    session = _SESSION_PROVIDER()  # pylint: disable=not-callable
  return session
83
84
class _ObjectGraphProtoPrettyPrinter(object):
  """Lazily traverses an object graph proto to pretty print names.

  If no calls to `node_names` are made this object has no performance
  overhead. On the other hand, it will only traverse the object graph once, so
  repeated naming is cheap after the first.
  """

  __slots__ = ["_object_graph_proto", "_node_name_cache"]

  def __init__(self, object_graph_proto):
    self._object_graph_proto = object_graph_proto
    self._node_name_cache = None

  @property
  def node_names(self):
    """Lazily creates a mapping from node id to a human-readable name.

    Nodes reachable through `children` edges are named by the dotted path of
    local names from the root, e.g. "(root).model.v". Slot variable nodes are
    named after the node holding the slot reference, the slot name, and the
    variable the slot is for. The mapping is computed on first access and
    cached afterwards.

    Returns:
      A dict mapping integer node ids to name strings.
    """
    if self._node_name_cache is not None:
      return self._node_name_cache
    nodes = self._object_graph_proto.nodes

    # Breadth-first traversal from the root (node 0): each node is named by
    # the first path that reaches it.
    path_to_root = {0: ("(root)",)}
    to_visit = collections.deque([0])
    while to_visit:
      node_id = to_visit.popleft()
      for child in nodes[node_id].children:
        if child.node_id not in path_to_root:
          path_to_root[child.node_id] = (
              path_to_root[node_id] + (child.local_name,))
          to_visit.append(child.node_id)

    # Use a distinct name for each path tuple so the dict being iterated is
    # never rebound inside its own loop.
    node_names = {
        node_id: ".".join(path) for node_id, path in path_to_root.items()
    }

    # Slot variables are referenced from their owner's `slot_variables` rather
    # than through `children`, so name them separately.
    for node_id, node in enumerate(nodes):
      for slot_reference in node.slot_variables:
        node_names[slot_reference.slot_variable_node_id] = (
            "{}'s state '{}' for {}".format(
                node_names[node_id], slot_reference.slot_name,
                node_names[slot_reference.original_variable_node_id]))
    self._node_name_cache = node_names
    return node_names
128
129
class _CheckpointRestoreCoordinatorDeleter(object):
  """Warns about checkpoint values left unused when it is garbage collected.

  This is a separate object so that `_CheckpointRestoreCoordinator` does not
  have to override `__del__()`. The coordinator hands this object the same
  bookkeeping containers it updates during the restore.
  """

  __slots__ = [
      "expect_partial", "object_graph_proto", "matched_proto_ids",
      "unused_attributes"
  ]

  def __init__(self, expect_partial, object_graph_proto, matched_proto_ids,
               unused_attributes):
    """Stores references to the restore's bookkeeping.

    Args:
      expect_partial: If true, `__del__` emits no warnings.
      object_graph_proto: The checkpoint's object graph proto; its nodes are
        checked and named in the warnings.
      matched_proto_ids: Container of node ids that were matched to Python
        objects. Every other node id is reported as unresolved.
      unused_attributes: Dict mapping node ids to the attributes that were in
        the checkpoint but not restored.
    """
    self.expect_partial = expect_partial
    self.object_graph_proto = object_graph_proto
    self.matched_proto_ids = matched_proto_ids
    self.unused_attributes = unused_attributes

  def set_expect_partial(self, expect_partial):
    """Sets whether `__del__` should stay silent."""
    self.expect_partial = expect_partial

  def __del__(self):
    if self.expect_partial:
      return
    # The logging module may have been unloaded by the time __del__ runs; fall
    # back to print in that case.
    log_fn = print if logging is None else logging.warning
    pretty_printer = _ObjectGraphProtoPrettyPrinter(self.object_graph_proto)
    warned = False
    node_count = len(self.object_graph_proto.nodes)
    for node_id in range(node_count):
      if node_id in self.matched_proto_ids:
        continue
      log_fn("Unresolved object in checkpoint: {}"
             .format(pretty_printer.node_names[node_id]))
      warned = True
    for node_id, attribute_name in self.unused_attributes.items():
      log_fn("Unused attribute in object {}: {}"
             .format(pretty_printer.node_names[node_id], attribute_name))
      warned = True
    if not warned:
      return
    log_fn(
        "A checkpoint was restored (e.g. tf.train.Checkpoint.restore or "
        "tf.keras.Model.load_weights) but not all checkpointed values were "
        "used. See above for specific issues. Use expect_partial() on the "
        "load status object, e.g. "
        "tf.train.Checkpoint.restore(...).expect_partial(), to silence these "
        "warnings, or use assert_consumed() to make the check explicit. See "
        "https://www.tensorflow.org/guide/checkpoint#loading_mechanics"
        " for details.")
177
178
class _CheckpointRestoreCoordinator(object):
  """Holds the status of an object-based checkpoint load."""

  def __init__(self, object_graph_proto, save_path, save_path_tensor,
               restore_op_cache, graph_view, options):
    """Specify the checkpoint being loaded.

    Args:
      object_graph_proto: The TrackableObjectGraph protocol buffer associated
        with this checkpoint.
      save_path: A string, the path to the checkpoint, as returned by
        `tf.train.latest_checkpoint`.
      save_path_tensor: A string `Tensor` which contains or will be fed the save
        path.
      restore_op_cache: A dictionary shared between
        `_CheckpointRestoreCoordinator`s for the same Python objects, used to
        look up restore ops by name to avoid re-creating them across multiple
        `restore()` calls.
      graph_view: A graph_view_lib.ObjectGraphView object for the restored
        objects.
      options: A CheckpointOptions object.
    """
    self.options = options
    self.object_graph_proto = object_graph_proto
    self.restore_uid = ops.uid()
    # Maps from proto ids to lists of attributes which were in the checkpoint
    # but not loaded into any object, for error checking.
    self.unused_attributes = {}
    # Dictionary mapping from an id in the protocol buffer flat array to
    # Trackable Python objects. This mapping may be deferred if a
    # checkpoint is restored before all dependencies have been tracked. Uses
    # weak references so that partial restorations don't create reference cycles
    # (as objects with deferred dependencies will generally have references to
    # this object).
    self.object_by_proto_id = weakref.WeakValueDictionary()
    # Proto ids which have been matched to Python objects. The deleter below
    # shares this set and warns about every node id missing from it.
    self.matched_proto_ids = set()
    # A set of all Python objects we've seen as dependencies, even if we didn't
    # use them (for example because of inconsistent references when
    # loading). Used to make status assertions fail when loading checkpoints
    # that don't quite match.
    self.all_python_objects = object_identity.ObjectIdentityWeakSet()
    self.save_path_tensor = save_path_tensor
    self.save_path_string = save_path
    # The checkpoint's variable-name -> dtype map, read once up front.
    self.dtype_map = py_checkpoint_reader.NewCheckpointReader(
        save_path).get_variable_to_dtype_map()
    # When graph building, contains a list of ops to run to restore objects from
    # this checkpoint.
    self.restore_ops = []
    self.restore_ops_by_name = restore_op_cache
    self.graph_view = graph_view
    # Optional callable invoked with each batch of ops passed to
    # `new_restore_ops` (for example, set by `streaming_restore`).
    self.new_restore_ops_callback = None
    # A mapping from optimizer proto ids to lists of slot variables to be
    # restored when the optimizer is tracked. Only includes slot variables whose
    # regular variables have already been created, and only for optimizer
    # objects which have not yet been created/tracked.
    self.deferred_slot_restorations = {}
    # A mapping from variable proto ids to lists of slot variables to be
    # restored when the variable is created/tracked. These get shifted over to
    # deferred_slot_restorations if the optimizer hasn't been created when that
    # happens.
    self.slot_restorations = {}
    # Controls whether errors are printed in __del__ if some objects did not
    # match.
    self.expect_partial_attr = False
    for node_index, node in enumerate(self.object_graph_proto.nodes):
      for slot_reference in node.slot_variables:
        # `node` refers to an `Optimizer`, since only these have slot variables.
        self.slot_restorations.setdefault(
            slot_reference.original_variable_node_id, []).append(
                base._SlotVariableRestoration(  # pylint: disable=protected-access
                    optimizer_id=node_index,
                    slot_variable_id=slot_reference.slot_variable_node_id,
                    slot_name=slot_reference.slot_name))

    # Warnings about unmatched values are emitted by this separate object when
    # it is garbage collected; it shares the bookkeeping containers above.
    self._deleter = _CheckpointRestoreCoordinatorDeleter(
        self.expect_partial_attr,
        self.object_graph_proto,
        self.matched_proto_ids,
        self.unused_attributes)

  @property
  def expect_partial(self):
    """Whether warnings about unmatched checkpoint values are silenced."""
    return self.expect_partial_attr

  @expect_partial.setter
  def expect_partial(self, expect_partial):
    self.expect_partial_attr = expect_partial
    # Keep the deleter in sync, since it decides whether to warn.
    self._deleter.set_expect_partial(expect_partial)

  def new_restore_ops(self, new_ops):
    """Records `new_ops` and forwards them to `new_restore_ops_callback`."""
    self.restore_ops.extend(new_ops)
    if self.new_restore_ops_callback:
      self.new_restore_ops_callback(new_ops)  # pylint: disable=not-callable

  def restore_saveables(self, tensor_saveables, python_saveables):
    """Run or build restore operations for SaveableObjects.

    Args:
      tensor_saveables: `SaveableObject`s which correspond to Tensors.
      python_saveables: `PythonStateSaveable`s which correspond to Python
        values.

    Returns:
      When graph building, a list of restore operations, either cached or newly
      created, to restore `tensor_saveables`.

    Raises:
      AssertionError: If validating `tensor_saveables` changed their names.
    """
    restore_ops = []
    # Eagerly run restorations for Python state.
    reader = None
    for saveable in python_saveables:
      if reader is None:
        # Lazily create the NewCheckpointReader, since this requires file access
        # and we may not have any Python saveables.
        reader = py_checkpoint_reader.NewCheckpointReader(self.save_path_string)
      spec_names = [spec.name for spec in saveable.specs]
      saveable.python_restore([reader.get_tensor(name) for name in spec_names])

    # If we have new SaveableObjects, extract and cache restore ops.
    if tensor_saveables:
      validated_saveables = saveable_object_util.validate_and_slice_inputs(
          tensor_saveables)
      validated_names = set(saveable.name for saveable in validated_saveables)
      if set(tensor_saveables.keys()) != validated_names:
        # "Got back" is what validation returned; "expecting" is the input.
        raise AssertionError(
            ("Saveable keys changed when validating. Got back %s, was "
             "expecting %s") % (validated_names, tensor_saveables.keys()))
      new_restore_ops = functional_saver.MultiDeviceSaver(
          validated_saveables).restore(self.save_path_tensor, self.options)
      if not context.executing_eagerly():
        for name, restore_op in sorted(new_restore_ops.items()):
          restore_ops.append(restore_op)
          assert name not in self.restore_ops_by_name
          self.restore_ops_by_name[name] = restore_op
    return restore_ops
315
316
class _NameBasedRestoreCoordinator(object):
  """Keeps the status of a name-based checkpoint restore."""

  def __init__(self, save_path, dtype_map=None):
    """Creates the coordinator.

    Args:
      save_path: Path prefix of the name-based checkpoint to read from.
      dtype_map: Mapping from checkpoint tensor names to dtypes. Required by
        `eager_restore`, which uses it both to check whether a tensor is in
        the checkpoint and to choose the dtype to read it as.
    """
    self.save_path = save_path
    self.dtype_map = dtype_map
    # A map from trackable objects to unused attribute names. We don't have
    # proto IDs when doing a name-based restore, so the map keys differ from
    # those in _CheckpointRestoreCoordinator.
    self.unused_attributes = object_identity.ObjectIdentityWeakKeyDictionary()
    self.restore_uid = ops.uid()

  def globally_named_object_attributes(self, trackable):
    """Create globally named SaveableObjects from attributes.

    If an object's attribute has no global name specified (default construction
    for the SaveableObject factory), records the failure in
    `self.unused_attributes` (which can then be used to make status assertions
    fail; see `NameBasedSaverStatus`).

    Args:
      trackable: An object to save.

    Yields:
      SaveableObjects for `trackable`'s attributes.
    """
    for attribute_name, saveable_factory in (
        trackable._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
      if callable(saveable_factory):
        try:
          saveable = saveable_factory()
        except TypeError:
          # This saveable object factory does not have a default name= argument,
          # which means there's no way to save/restore it using a name-based
          # checkpoint. Ignore the error now and make sure assert_consumed()
          # fails. Even though we can't name this object, we construct it (with
          # an empty name) to check whether restoring it is optional; if so,
          # there is no need to make assertions fail.
          if not saveable_factory("").optional_restore:
            self.unused_attributes.setdefault(trackable,
                                              []).append(attribute_name)
          continue
      else:
        saveable = saveable_factory
      names_to_saveables = saveable_object_util.op_list_to_dict(
          [saveable], convert_variable_to_tensor=False)
      for name, op in names_to_saveables.items():
        for saveable_object in saveable_object_util.saveable_objects_for_op(
            op=op, name=name):
          yield saveable_object

  def eager_restore(self, trackable):
    """Runs restore ops for `trackable`'s attributes.

    A saveable is restored only if every one of its specs is present in
    `self.dtype_map`. Otherwise nothing is read for it and its name is recorded
    in `self.unused_attributes` so status assertions fail. Values missing from
    the checkpoint are ignored rather than raising, as with object-based
    restore; status assertions can be used to check exact matches, although
    that is unlikely to ever happen for name-based checkpoints.

    Args:
      trackable: The object whose attributes should be restored.
    """
    # When graph building, we don't add any restore ops to the graph until
    # run_restore_ops/initialize_or_restore on the status object for name-based
    # checkpoints.
    assert context.executing_eagerly()
    for saveable in self.globally_named_object_attributes(trackable):
      if not all(spec.name in self.dtype_map for spec in saveable.specs):
        # Record that this variable didn't match so assertions will fail. The
        # tensors that are present would be discarded, so don't read them.
        self.unused_attributes.setdefault(trackable, []).append(saveable.name)
        continue
      restored_tensors = []
      for spec in saveable.specs:
        with ops.device("cpu:0"):
          restored, = io_ops.restore_v2(
              prefix=self.save_path,
              tensor_names=[spec.name],
              shape_and_slices=[""],
              dtypes=[self.dtype_map[spec.name]],
              name="%s_checkpoint_read" % (spec.name,))
        restored_tensors.append(array_ops.identity(restored))
      saveable.restore(
          restored_tensors=restored_tensors, restored_shapes=None)
400
401
# TODO(allenl): If this ends up in a public API, consider adding LINT.If Change
# or consolidating the implementation with get_variable.
def _default_getter(name,
                    shape,
                    dtype,
                    initializer=None,
                    partition_info=None,
                    **kwargs):
  """A pared-down version of get_variable which does not reuse variables.

  Always creates a new variable with `use_resource=True`.

  Args:
    name: Name for the new variable.
    shape: Shape of the variable, or `None`. Must be `None` when `initializer`
      is a (non-callable) constant value.
    dtype: Anything accepted by `dtypes.as_dtype`; the variable is created with
      its base dtype.
    initializer: A callable initializer, an initializer type object (which is
      instantiated with `dtype=dtype`), a constant initial value, or `None` to
      use the default variable store's default initializer.
    partition_info: Passed to the initializer only if its argument list
      includes `partition_info`.
    **kwargs: Passed through to `variables.VariableV1`.

  Returns:
    The newly created variable.

  Raises:
    ValueError: If `initializer` is a constant and `shape` is not `None`.
  """
  dtype = dtypes.as_dtype(dtype)
  shape_object = tensor_shape.as_shape(shape)
  with ops.init_scope():
    if initializer is not None:
      initializing_from_value = not callable(initializer)
    else:
      initializer, initializing_from_value = (
          variable_scope._get_default_variable_store()._get_default_initializer(  # pylint: disable=protected-access
              name=name,
              shape=shape_object,
              dtype=dtype))
    # Same logic as get_variable
    if not initializing_from_value:
      # Instantiate initializer if provided initializer is a type object.
      if isinstance(initializer, type(init_ops.Initializer)):
        initializer = initializer(dtype=dtype)
      shape_list = shape_object.as_list() if shape is not None else None
      partial_kwargs = {"dtype": dtype}
      if "partition_info" in tf_inspect.getargspec(initializer).args:
        partial_kwargs["partition_info"] = partition_info
      initial_value = functools.partial(initializer, shape_list,
                                        **partial_kwargs)
    elif shape is not None:
      raise ValueError("If initializer is a constant, do not specify shape.")
    else:
      initial_value = initializer

    return variables.VariableV1(
        initial_value=initial_value,
        name=name,
        dtype=dtype.base_dtype,
        use_resource=True,
        **kwargs)
449
450
def add_variable(trackable,
                 name,
                 shape=None,
                 dtype=dtypes.float32,
                 initializer=None,
                 trainable=True):
  """Add a variable to a Trackable with no scope influence.

  The variable is created through `trackable`'s custom-getter hook using
  `_default_getter`, so variable scopes do not affect it.

  Args:
    trackable: The object that will depend on the new variable.
    name: Name of the variable/dependency.
    shape: Optional shape of the variable.
    dtype: The variable's dtype.
    initializer: Optional initializer or initial value.
    trainable: Whether the variable is trainable.

  Returns:
    Whatever `trackable._add_variable_with_custom_getter` returns.
  """
  getter_kwargs = dict(
      name=name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      getter=_default_getter,
      trainable=trainable)
  return trackable._add_variable_with_custom_getter(**getter_kwargs)  # pylint: disable=protected-access
465
466
def object_metadata(save_path):
  """Retrieves information about the objects in a checkpoint.

  Example usage:

  ```python
  object_graph = object_metadata(
      tf.train.latest_checkpoint(checkpoint_directory))
  ckpt_variable_names = set()
  for node in object_graph.nodes:
    for attribute in node.attributes:
      ckpt_variable_names.add(attribute.full_name)
  ```

  Args:
    save_path: The path to the checkpoint, as returned by `save` or
      `tf.train.latest_checkpoint`.

  Returns:
    A parsed `TrackableObjectGraph` protocol buffer.

  Raises:
    ValueError: If an object graph was not found in the checkpoint.
  """
  reader = py_checkpoint_reader.NewCheckpointReader(save_path)
  try:
    serialized_graph = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
  except errors_impl.NotFoundError:
    # Name-based checkpoints do not store the object graph under this key.
    raise ValueError(
        ('The specified checkpoint "%s" does not appear to be object-based (it '
         'is missing the key "%s"). Likely it was created with a name-based '
         "saver and does not contain an object dependency graph.") %
        (save_path, base.OBJECT_GRAPH_PROTO_KEY))
  graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  graph_proto.ParseFromString(serialized_graph)
  return graph_proto
502
503
def list_objects(root_trackable):
  """Returns every object reachable from `root_trackable`, as a flat list.

  The objects are `Trackable` dependencies of `root_trackable`. Slot variables
  are included only if both the variable they slot for and the optimizer are
  dependencies of `root_trackable` (i.e. if they would be saved with a
  checkpoint).

  Args:
    root_trackable: A `Trackable` object whose dependencies should be flattened.

  Returns:
    A flat list of objects.
  """
  view = graph_view_lib.ObjectGraphView(root_trackable)
  return view.list_objects()
519
520
def gather_initializers(root_trackable):
  """Collects initialization ops from the objects reachable from a root.

  Walks the objects returned by `list_objects(root_trackable)`, so slot
  variable initializers are included only when both the slotted variable and
  the optimizer would be saved with a checkpoint. Objects without an
  `initializer` attribute, or whose `initializer` is `None`, are skipped.

  Args:
    root_trackable: A `Trackable` object to gather initializers for.

  Returns:
    A list of initialization ops.
  """
  initializers = []
  for candidate in list_objects(root_trackable):
    if not hasattr(candidate, "initializer"):
      continue
    if candidate.initializer is None:
      continue
    initializers.append(candidate.initializer)
  return initializers
542
543
@tf_contextlib.contextmanager
def capture_dependencies(template):
  """Capture variables created within this scope as `Template` dependencies.

  Requires that `template.variable_scope` is active.

  This scope is intended as a compatibility measure, allowing a trackable
  object to add dependencies on variables created in a block of code which is
  not aware of object-based saving (and instead uses variable names
  heavily). This is how `Template` objects add dependencies on variables and
  sub-`Template`s. Where possible, use `tf.compat.v1.make_template` directly.

  Implemented by installing a variable creator (via
  `variable_scope.variable_creator_scope`) for the duration of the `with`
  block.

  Args:
    template: The `Template` object to register dependencies with.

  Yields:
    None (when used as a context manager).
  """
  # Full name of the template's variable scope. Variables whose names start
  # with it are treated as belonging to `template`.
  name_prefix = template.variable_scope.name

  def _trackable_custom_creator(next_creator,
                                name,
                                initial_value,
                                trackable_parent=None,
                                **kwargs):
    """A variable creation hook which adds Trackable dependencies.

    Set for example during a `Template`'s first wrapped function
    execution. Ensures that (a) `template` depends on any trackable
    objects using their own `capture_dependencies` scope inside this scope which
    create variables, and (b) that any variables not in a more deeply nested
    scope are added as dependencies directly.

    The `trackable_parent` argument is passed between custom creators but
    ignored when the variable object itself is created. This argument indicates
    (if not `None`) that a more deeply nested scope has already added the
    variable as a dependency, and that parent scopes should add a dependency on
    that object rather than on the variable directly.

    Args:
      next_creator: See `variable_scope.variable_creator_scope`; the next
        creator in the chain.
      name: The (full, scope-influenced) name of the variable. The `name_prefix`
        itself is stripped for the purposes of object-based dependency tracking,
        but scopes opened within this scope are respected.
      initial_value: See `variable_scope.variable_creator_scope`. Taken
        explicitly so the argument can be re-named and used with
        `Trackable._add_variable_with_custom_getter`.
      trackable_parent: If not None, a more deeply nested trackable object and
        its name prefix which were passed to `capture_dependencies` to add a
        dependency on (rather than depending on the variable directly).
      **kwargs: Passed through to the next creator.

    Returns:
      The output of `next_creator`: the fetched/created variable object.
    """

    # Getter handed to `_add_variable_with_custom_getter`. It receives the
    # initial value as `initializer` and forwards it to `next_creator` as
    # `initial_value`, together with the full (unstripped) `name` of the
    # enclosing creator call.
    def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
      inner_kwargs.pop("name")  # Ignored; this is the scope-stripped name which
      # we don't want to propagate.
      return next_creator(initial_value=initializer, name=name, **inner_kwargs)

    # NOTE(review): `startswith` also matches names in sibling scopes that only
    # share a textual prefix (e.g. prefix "a" vs. name "ab/v"), and the `+ 1`
    # slicing below assumes one separator character follows the prefix --
    # confirm this is intended.
    if name is not None and name.startswith(name_prefix):
      # Name relative to the template's scope: drops the prefix plus one
      # separator character.
      scope_stripped_name = name[len(name_prefix) + 1:]
      if not trackable_parent:
        # No more deeply nested scope has tracked this variable, so make it a
        # direct dependency of `template`. The variable itself is created by
        # the getter above, which calls `next_creator`.
        return template._add_variable_with_custom_getter(  # pylint: disable=protected-access
            initializer=initial_value,
            name=scope_stripped_name,
            getter=_call_next_creator_renaming_initializer,
            # Disable error checking for Trackable. Exceptions are instead
            # raised if necessary when the object-based saver tries to
            # save/restore the object.
            overwrite=True,
            trackable_parent=(template, name_prefix),
            **kwargs)
      else:
        # A nested scope already tracked the variable; depend on that nested
        # object instead, named by its prefix relative to this scope.
        parent_object, parent_name_prefix = trackable_parent
        template._track_trackable(  # pylint: disable=protected-access
            parent_object,
            name=parent_name_prefix[len(name_prefix) + 1:],
            overwrite=True)
    # Pass `template` along as the trackable parent, so enclosing
    # `capture_dependencies` scopes depend on it rather than on the variable.
    return next_creator(
        name=name,
        initial_value=initial_value,
        trackable_parent=(template, name_prefix),
        **kwargs)

  # The creator is active only while the caller's `with` block runs.
  with variable_scope.variable_creator_scope(_trackable_custom_creator):
    yield
633
634
class _LoadStatus(object):
  """Abstract base for load status callbacks.

  NOTE(review): the methods below are decorated with `abc.abstractmethod`, but
  this class does not use `abc.ABCMeta` as its metaclass. The decorators are
  therefore not enforced, and the class can be instantiated directly.
  """

  @abc.abstractmethod
  def assert_consumed(self):
    """Raises unless a non-trivial restoration has completed."""

  @abc.abstractmethod
  def assert_existing_objects_matched(self):
    """Raises unless the Python objects that exist so far have been matched."""

  @abc.abstractmethod
  def assert_nontrivial_match(self):
    """Raises if nothing except the root object was matched."""

  @abc.abstractmethod
  def run_restore_ops(self, session=None):
    """Runs the checkpoint's restore ops; requires a valid checkpoint."""

  @abc.abstractmethod
  def initialize_or_restore(self, session=None):
    """Runs the checkpoint's restore ops, or initializes variables instead."""

  def expect_partial(self):
    """Silences warnings about incomplete checkpoint restores.

    Returns:
      `self`, for chaining.
    """
    return self
666
667
def streaming_restore(status, session=None):
  """When graph building, runs restore ops as soon as they come in.

  Runs the restore ops created so far. It then sets a callback on the status's
  restore coordinator, so restore ops created later are run in `session` with
  the status's feed dict. Does nothing when executing eagerly.

  Args:
    status: A `_LoadStatus` object from an object-based saver's restore().
      Streaming restore from name-based checkpoints is not currently supported.
    session: A session to run new restore ops in. Defaults to `get_session()`.

  Raises:
    NotImplementedError: If `status` is a `NameBasedSaverStatus` (when graph
      building).
  """
  if context.executing_eagerly():
    # Streaming restore is the default/only behavior when executing eagerly.
    return
  if session is None:
    session = get_session()
  if isinstance(status, NameBasedSaverStatus):
    raise NotImplementedError(
        "Streaming restore not supported from name-based checkpoints when "
        "graph building. File a feature request if this limitation bothers "
        "you. As a workaround, consider either using tf.train.Checkpoint to "
        "load name-based checkpoints or enabling eager execution.")
  status.run_restore_ops(session=session)
  # The callback's argument is named `new_ops` so it does not shadow the
  # module-level `ops` import.
  # pylint: disable=protected-access
  status._checkpoint.new_restore_ops_callback = (
      lambda new_ops: session.run(new_ops, feed_dict=status._feed_dict))
  # pylint: enable=protected-access
692
693
def _objects_with_attributes(full_list):
  """Filters out objects with no direct variable dependencies for assertions.

  Args:
    full_list: Iterable of trackable objects.

  Returns:
    A list of the objects whose `_gather_saveables_for_checkpoint()` result is
    non-empty, in their original order.
  """

  def _has_saveables(candidate):
    return bool(candidate._gather_saveables_for_checkpoint())  # pylint: disable=protected-access

  return list(filter(_has_saveables, full_list))
697
698
class CheckpointLoadStatus(_LoadStatus):
  """Checks the status of checkpoint loading and manages restore ops.

  Returned from `Saver.restore`. Since `restore` may defer the loading of values
  in the checkpoint which don't yet have corresponding Python objects,
  `CheckpointLoadStatus` provides a callback to verify that checkpoint loading
  is complete (`assert_consumed`).

  When graph building, `restore` does not run restore ops itself since their
  creation may be deferred. The `run_restore_ops` method must be called once all
  Python objects with values to restore have been created and added to the
  dependency graph (this does not necessarily have to be the whole checkpoint;
  calling `run_restore_ops` while `assert_consumed` fails is supported and will
  partially restore the checkpoint).

  See `Saver.restore` for usage examples.
  """

  def __init__(self, checkpoint, feed_dict, graph_view):
    """Creates a load status for an object-based checkpoint restore.

    Args:
      checkpoint: The `_CheckpointRestoreCoordinator` driving this restore.
      feed_dict: Feeds passed to `session.run` when running restore ops while
        graph building (for example the checkpoint file prefix), or `None`.
      graph_view: The object graph view whose objects are being restored.
    """
    self._checkpoint = checkpoint
    self._feed_dict = feed_dict
    self._graph_view = graph_view
    # Keep a reference to the root, since graph_view might only have a weakref.
    self._root = graph_view.root

  def assert_consumed(self):
    """Asserts that all objects in the checkpoint have been created/matched.

    Returns:
      `self` for chaining.
    Raises:
      AssertionError: If there are any Python objects in the dependency graph
        which have not been restored from this checkpoint or a later `restore`,
        or if there are any checkpointed values which have not been matched to
        Python objects.
    """
    pretty_printer = _ObjectGraphProtoPrettyPrinter(
        self._checkpoint.object_graph_proto)
    self.assert_existing_objects_matched()
    for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
      if not node.attributes:
        # Only raise exceptions for the nodes with attributes themselves. Either
        # they're ultimately not important, or they have a child with an
        # attribute.
        continue
      trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
      if trackable is None:
        raise AssertionError("Unresolved object in checkpoint {}: {}"
                             .format(pretty_printer.node_names[node_id], node))
    if self._checkpoint.slot_restorations:
      # Sanity check; this collection should be clear if everything has been
      # restored.
      raise AssertionError("Unresolved slot restorations: %s" %
                           (self._checkpoint.slot_restorations,))
    if self._checkpoint.unused_attributes:
      unused_attribute_messages = []
      for node_id, attribute in six.iteritems(
          self._checkpoint.unused_attributes):
        obj = self._checkpoint.object_by_proto_id[node_id]
        unused_attribute_messages.append(
            "{} ({}): {}"
            .format(pretty_printer.node_names[node_id], obj, attribute))
      raise AssertionError(
          ("Unused attributes in these objects (the attributes exist in the "
           "checkpoint but were not restored):\n{}")
          .format("\n".join(unused_attribute_messages)))
    return self

  def assert_existing_objects_matched(self):
    """Asserts that trackable Python objects have been matched.

    Note that this is a weaker assertion than `assert_consumed`. It will only
    fail for existing Python objects which are (transitive) dependencies of the
    root object and which do not have an entry in the checkpoint.

    It will not fail, for example, if a `tf.keras.Layer` object has not yet been
    built and so has not created any `tf.Variable` objects.

    Returns:
      `self` for chaining.

    Raises:
      AssertionError: If a Python object exists in the transitive dependencies
        of the root object but does not have a value in the checkpoint.
    """
    for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
      trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
      if (trackable is not None and
          trackable._update_uid < self._checkpoint.restore_uid):  # pylint: disable=protected-access
        raise AssertionError("Object not assigned a value from checkpoint: %s" %
                             (node,))
    for trackable_object in self._graph_view.list_objects():
      # Remove data structures that do not contain any variables from
      # restoration checks.
      if (isinstance(trackable_object,
                     data_structures.TrackableDataStructure) and
          not trackable_object._checkpoint_dependencies):
        continue
      self._checkpoint.all_python_objects.add(trackable_object)
    unused_python_objects = (
        object_identity.ObjectIdentitySet(
            _objects_with_attributes(
                self._checkpoint.all_python_objects)) -
        object_identity.ObjectIdentitySet(
            self._checkpoint.object_by_proto_id.values()))
    if unused_python_objects:
      raise AssertionError(
          ("Some Python objects were not bound to checkpointed values, likely "
           "due to changes in the Python program: %s") %
          (list(unused_python_objects),))
    return self

  def assert_nontrivial_match(self):
    """Raises an exception if only the root object matched.

    Returns:
      `self` for chaining.

    Raises:
      AssertionError: If at most one object (the root) was matched to a
        checkpointed value.
    """
    for trackable_object in self._graph_view.list_objects():
      self._checkpoint.all_python_objects.add(trackable_object)
    if len(self._checkpoint.object_by_proto_id) <= 1:
      unused_python_objects = (
          object_identity.ObjectIdentitySet(
              _objects_with_attributes(self._checkpoint.all_python_objects))
          - object_identity.ObjectIdentitySet(
              self._checkpoint.object_by_proto_id.values()))
      if unused_python_objects:
        raise AssertionError(
            ("Nothing except the root object matched a checkpointed value. "
             "Typically this means that the checkpoint does not match the "
             "Python program. The following objects have no matching "
             "checkpointed value: %s") % (list(unused_python_objects),))
      else:
        raise AssertionError(
            "Nothing to load. No dependencies have been added to %s yet." %
            (self._graph_view.root,))
    return self

  def run_restore_ops(self, session=None):
    """Run operations to restore objects in the dependency graph.

    Does nothing when executing eagerly.

    Args:
      session: The session to run restore ops in. If `None`, uses the default
        session.
    """
    if context.executing_eagerly():
      return  # Run eagerly
    if session is None:
      session = get_session()
    session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict)

  def initialize_or_restore(self, session=None):
    """Run operations to initialize or restore objects in the dependency graph.

    Any objects in the dependency graph which have initializers but are not in
    the checkpoint will have those initializers run, unless those variables are
    being restored by a later call to `tf.train.Checkpoint.restore()`.

    This method has a sibling in `InitializationOnlyStatus` which instead
    initializes variables. That type is returned if no checkpoint is specified
    in `Saver.restore`.

    Args:
      session: The session to run init/restore ops in. If `None`, uses the
        default session.
    """
    if context.executing_eagerly():
      return  # Initialization and restoration ops are run eagerly
    if session is None:
      session = get_session()
    all_objects = self._graph_view.list_objects()
    already_initialized_objects = object_identity.ObjectIdentitySet(
        self._checkpoint.object_by_proto_id.values())
    initializers_for_non_restored_variables = [
        c.initializer for c in all_objects
        if hasattr(c, "initializer")
        # An `initializer` attribute may be `None`, which `session.run` cannot
        # fetch; skip it, as `InitializationOnlyStatus` already does.
        and c.initializer is not None
        and c not in already_initialized_objects
        # Objects updated by this (or a later) restore must not be
        # re-initialized; objects without `_update_uid` always qualify.
        and (getattr(c, "_update_uid", self._checkpoint.restore_uid - 1)
             < self._checkpoint.restore_uid)]
    self.run_restore_ops(session=session)
    session.run(initializers_for_non_restored_variables)

  def expect_partial(self):
    """Silence warnings about incomplete checkpoint restores.

    Returns:
      `self` for chaining.
    """
    self._checkpoint.expect_partial = True
    return self
877
class InitializationOnlyStatus(_LoadStatus):
  """Load status produced by `Saver.restore` when `save_path` is `None`.

  No checkpoint is read, so every assertion method and `run_restore_ops`
  raise `AssertionError`. The one useful method is `initialize_or_restore`.
  It runs initializers for the objects in the dependency graph, so callers can
  use the same code path on both this type and `CheckpointLoadStatus`.
  """

  def __init__(self, graph_view, restore_uid):
    self._graph_view = graph_view
    self._restore_uid = restore_uid
    # graph_view may hold only a weak reference to the root; holding a strong
    # one here keeps the root alive as long as this status object exists.
    self._root = graph_view.root

  def assert_consumed(self):
    """Always raises; there is no checkpoint whose values could be consumed."""
    raise AssertionError(
        "No checkpoint specified (save_path=None); nothing is being restored.")

  def assert_existing_objects_matched(self):
    """Always raises; nothing can be matched without a checkpoint."""
    raise AssertionError(
        "No checkpoint specified (save_path=None); nothing is being restored.")

  def assert_nontrivial_match(self):
    """Always raises; nothing can be matched without a checkpoint."""
    raise AssertionError(
        "No checkpoint specified (save_path=None); nothing is being restored.")

  def run_restore_ops(self, session=None):
    """Always raises; exists for interface parity with `CheckpointLoadStatus`.

    Call `initialize_or_restore` instead. It initializes when no checkpoint
    was passed to `Saver.restore` and restores otherwise.

    Args:
      session: Ignored.

    Raises:
      AssertionError: Always.
    """
    del session  # Unused.
    raise AssertionError(
        "No checkpoint specified, so no restore ops are available "
        "(save_path=None to Saver.restore).")

  def initialize_or_restore(self, session=None):
    """Runs the initializers of objects in the dependency graph.

    An object is initialized only if it exposes a non-`None` `initializer`.
    It is skipped if its `_update_uid` shows that a restore at or after this
    status's `restore_uid` has already touched it (i.e. a later
    `tf.train.Checkpoint.restore()` is handling it). Does nothing when
    executing eagerly.

    Args:
      session: The session to run initializers in. If `None`, the default
        session is used.
    """
    if context.executing_eagerly():
      return  # Initializers already ran eagerly.
    if session is None:
      session = get_session()
    restore_uid = self._restore_uid
    pending_initializers = []
    for candidate in self._graph_view.list_objects():
      if not hasattr(candidate, "initializer") or candidate.initializer is None:
        continue
      if getattr(candidate, "_update_uid", restore_uid - 1) >= restore_uid:
        continue
      pending_initializers.append(candidate.initializer)
    session.run(pending_initializers)
947
948
# Deprecation instructions attached to `NameBasedSaverStatus.__init__` via
# `deprecation.deprecated`. Shown when an object-based saver falls back to
# loading a name-based checkpoint.
_DEPRECATED_RESTORE_INSTRUCTIONS = (
    "Restoring a name-based tf.train.Saver checkpoint using the "
    "object-based restore API. "
    "This mode uses global names to match variables, and so is somewhat "
    "fragile. "
    "It also adds new restore ops to the graph each time it is called when "
    "graph building. "
    "Prefer re-encoding training checkpoints in the object-based format: "
    "run save() on the object-based saver (the same one this message is "
    "coming from) and use that checkpoint in the future.")
956
957
class NameBasedSaverStatus(_LoadStatus):
  """Load status for a checkpoint written by a name-based saver.

  Variables are matched by their global names rather than by object-graph
  position. In graph mode, restore ops are only built and run when
  `run_restore_ops` (or its alias `initialize_or_restore`) is called.
  """

  # Decorating the class itself with the deprecation wrapper would break
  # isinstance checks, so the decorator lives on __init__ instead.
  @deprecation.deprecated(
      date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
  def __init__(self, checkpoint, graph_view):
    self._checkpoint = checkpoint
    self._graph_view = graph_view
    self._optionally_restored = []
    # graph_view may hold only a weak reference to the root; keep a strong one
    # so the root stays alive for as long as this status object does.
    self._root = graph_view.root

  def add_to_optionally_restored(self, var):
    """Marks `var` as optional for the unused-attribute check.

    Some variables should not trip assertions such as
    `assert_existing_objects_matched()`. For example, a checkpoint written by
    `train.Saver()` and loaded with `train.Checkpoint()` may lack the internal
    `save_counter` variable. Objects registered here (compared by identity)
    are skipped when reporting unrestored attributes.

    Args:
      var: The variable to treat as optionally restored.
    """
    self._optionally_restored.append(var)

  def assert_consumed(self):
    """Raises `AssertionError` if anything was left unrestored.

    Two checks run in order:
      1. No object, other than those passed to `add_to_optionally_restored`,
         may have attributes recorded as unused by the restore coordinator.
      2. Every object in the dependency graph must have an `_update_uid` at
         least as new as this restore.

    Returns:
      `self` for chaining.
    """
    unrestored = []
    for obj, attributes in list(self._checkpoint.unused_attributes.items()):
      if any(obj is optional for optional in self._optionally_restored):
        continue
      unrestored.append((obj, attributes))
    if unrestored:
      details = "".join(
          "\n    {}: {}".format(obj, attributes)
          for obj, attributes in unrestored)
      raise AssertionError(
          "Some objects had attributes which were not restored:{}".format(
              details))
    # pylint: disable=protected-access
    for candidate in self._graph_view.list_objects():
      candidate._maybe_initialize_trackable()
      if candidate._update_uid < self._checkpoint.restore_uid:
        raise AssertionError("Object not restored: %s" % (candidate,))
    # pylint: enable=protected-access
    return self

  def assert_existing_objects_matched(self):
    """Same check as `assert_consumed` for name-based checkpoints."""
    # A name-based checkpoint carries no object-graph information, so there is
    # nothing to distinguish "existing objects matched" from "fully consumed".
    return self.assert_consumed()

  def assert_nontrivial_match(self):
    """Same check as `assert_consumed` for name-based checkpoints."""
    # Without object-graph information a "nontrivial match" check cannot be
    # expressed separately, so defer to the full consumption check.
    return self.assert_consumed()

  def _gather_saveable_objects(self):
    """Collects globally named SaveableObjects for not-yet-restored objects.

    Each object whose `_update_uid` is older than this restore is stamped with
    the restore's uid, and its saveables are gathered. Objects that are
    already up to date are skipped.

    Returns:
      A list of SaveableObjects keyed by global names.
    """
    collected = []
    # pylint: disable=protected-access
    for candidate in self._graph_view.list_objects():
      candidate._maybe_initialize_trackable()
      if candidate._update_uid >= self._checkpoint.restore_uid:
        continue
      candidate._update_uid = self._checkpoint.restore_uid
      collected.extend(
          self._checkpoint.globally_named_object_attributes(candidate))
    # pylint: enable=protected-access
    return collected

  def run_restore_ops(self, session=None):
    """Load the name-based checkpoint using a new `tf.compat.v1.train.Saver`."""
    if context.executing_eagerly():
      return  # Variables were already restored as they were created.
    session = get_session() if session is None else session
    with ops.device("/cpu:0"):
      v1_saver = v1_saver_lib.Saver(self._gather_saveable_objects())
      v1_saver.restore(sess=session, save_path=self._checkpoint.save_path)

  def initialize_or_restore(self, session=None):
    """Alias for `run_restore_ops`."""
    self.run_restore_ops(session=session)
1055
1056
class _SessionWithFeedDictAdditions(session_lib.SessionInterface):
  """Session stand-in that merges fixed extra feeds into every `run()` call."""

  def __init__(self, session, feed_additions):
    """Wraps `session`.

    Args:
      session: The session that `run` calls are forwarded to.
      feed_additions: Extra feeds merged (via `dict.update`) into the caller's
        `feed_dict` on every call; they override colliding keys.
    """
    self._wrapped_session = session
    self._feed_additions = feed_additions

  def run(self, fetches, feed_dict=None, **kwargs):
    """Runs `fetches` on the wrapped session with the extra feeds merged in.

    The caller's `feed_dict` is copied, never mutated.
    """
    merged_feeds = {} if feed_dict is None else feed_dict.copy()
    merged_feeds.update(self._feed_additions)
    return self._wrapped_session.run(
        fetches=fetches, feed_dict=merged_feeds, **kwargs)
1072
1073
class TrackableSaver(object):
  """Saves and restores a `Trackable` object and its dependencies.

  See `Trackable` for details of dependency management. `Saver` wraps
  `tf.compat.v1.train.Saver` for saving, including extra information about the
  graph of dependencies between Python objects. When restoring, it uses this
  information about the save-time dependency graph to more robustly match
  objects with their checkpointed values. When executing eagerly, it supports
  restoring variables on object creation (see `Saver.restore`).

  Values in a checkpoint are mapped to `Trackable` Python objects
  (`Variable`s, `Optimizer`s, `Layer`s) based on the names provided when the
  checkpoint was written. To avoid breaking existing checkpoints when modifying
  a class, dependency names (the names of attributes to which `Trackable`
  objects are assigned) may not change. These names are local to objects, in
  contrast to the `Variable.name`-based save/restore from
  `tf.compat.v1.train.Saver`, and so allow additional program transformations.
  """

  def __init__(self, graph_view):
    """Configure saving.

    Args:
      graph_view: A `GraphView` object containing a description of the object
        graph to save.
    """
    # The file prefix placeholder is created lazily when graph building (and not
    # at all when executing eagerly) to avoid creating ops in the constructor
    # (when they may never be necessary).
    self._file_prefix_placeholder = None

    # Op caching for save
    self._object_graph_feed_tensor = None
    self._last_save_object_graph = None
    self._file_prefix_feed_tensor = None
    self._cached_save_operation = None

    # Op caching for restore, shared between _CheckpointRestoreCoordinators
    self._restore_op_cache = {}
    self._graph_view = graph_view

  def _gather_saveables(self, object_graph_tensor=None):
    """Wraps _serialize_object_graph to include the object graph proto.

    Args:
      object_graph_tensor: Optional tensor that the serialized object graph is
        fed into. If `None`, a constant holding the serialized proto is created
        on the CPU instead.

    Returns:
      A tuple `(named_saveable_objects, graph_proto, feed_additions)`. The
      saveable list ends with a `NoRestoreSaveable` holding the object graph
      under `base.OBJECT_GRAPH_PROTO_KEY`.
    """
    (named_saveable_objects, graph_proto,
     feed_additions) = self._graph_view.serialize_object_graph()
    if object_graph_tensor is None:
      with ops.device("/cpu:0"):
        object_graph_tensor = constant_op.constant(
            graph_proto.SerializeToString(), dtype=dtypes.string)
    else:
      feed_additions.update(
          {object_graph_tensor: graph_proto.SerializeToString()})
    # NOTE(review): `named_saveable_objects` is a list of saveable objects, so
    # testing a string key for membership is unlikely to ever fail; a check on
    # the saveables' names may have been intended.
    assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
    named_saveable_objects.append(
        base.NoRestoreSaveable(
            tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY))
    return named_saveable_objects, graph_proto, feed_additions

  def _save_cached_when_graph_building(self,
                                       file_prefix,
                                       object_graph_tensor,
                                       options):
    """Create or retrieve save ops.

    Args:
      file_prefix: The prefix for saved checkpoint files.
      object_graph_tensor: A `Tensor` to which the current object graph will be
        fed.
      options: `CheckpointOptions` object.

    Returns:
      A two-element tuple with a filename tensor and a feed_dict of tensors to
      feed when running it (if graph building). The feed dict contains the
      current object graph and any Python state to be saved in the
      checkpoint. When executing eagerly only the first argument is meaningful.
    """
    (named_saveable_objects, graph_proto,
     feed_additions) = self._gather_saveables(
         object_graph_tensor=object_graph_tensor)
    if (self._last_save_object_graph != graph_proto
        # When executing eagerly, we need to re-create SaveableObjects each time
        # save() is called so they pick up new Tensors passed to their
        # constructors. That means the Saver needs to be copied with a new
        # var_list.
        or context.executing_eagerly() or ops.inside_function()):
      saver = functional_saver.MultiDeviceSaver(named_saveable_objects)
      save_op = saver.save(file_prefix, options=options)
      with ops.device("/cpu:0"):
        with ops.control_dependencies([save_op]):
          self._cached_save_operation = array_ops.identity(file_prefix)
      self._last_save_object_graph = graph_proto
    return self._cached_save_operation, feed_additions

  def save(self, file_prefix, checkpoint_number=None, session=None,
           options=None):
    """Save a training checkpoint.

    The saved checkpoint includes variables created by this object and any
    Trackable objects it depends on at the time `Saver.save()` is called.

    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix). Names are generated based on this
        prefix and `checkpoint_number`, if provided.
      checkpoint_number: An integer variable or Tensor, used to number
        checkpoints. Typically this value is saved along with other variables in
        training checkpoints, which will happen automatically if it was created
        by `root_trackable` or one of its dependencies (via
        `Trackable._add_variable`). A falsy value (including 0) adds no
        suffix.
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.
      options: Optional `tf.train.CheckpointOptions` object.

    Returns:
      The full path to the checkpoint.
    """
    options = options or checkpoint_options.CheckpointOptions()
    feed_dict = {}
    use_session = (not context.executing_eagerly() and
                   not ops.inside_function())
    # NOTE(review): this is a truthiness test, so checkpoint_number=0 produces
    # no "-0" suffix; kept as-is because callers may rely on the file names.
    if checkpoint_number:
      file_prefix = "%s-%d" % (file_prefix, checkpoint_number)
    if use_session:
      # Graph mode: build the feed tensors once and reuse them, feeding the
      # concrete prefix on each call so the cached save op can be reused.
      if self._object_graph_feed_tensor is None:
        with ops.device("/cpu:0"):
          self._object_graph_feed_tensor = constant_op.constant(
              "", dtype=dtypes.string)
          self._file_prefix_feed_tensor = constant_op.constant(
              "", dtype=dtypes.string)
      object_graph_tensor = self._object_graph_feed_tensor
      file_prefix_tensor = self._file_prefix_feed_tensor
      feed_dict[file_prefix_tensor] = file_prefix
    else:
      with ops.device("/cpu:0"):
        file_prefix_tensor = constant_op.constant(
            file_prefix, dtype=dtypes.string)
      object_graph_tensor = None

    file_io.recursive_create_dir(os.path.dirname(file_prefix))
    save_path, new_feed_additions = self._save_cached_when_graph_building(
        file_prefix_tensor, object_graph_tensor, options)
    if new_feed_additions:
      feed_dict.update(new_feed_additions)
    if not use_session:
      session = None
    elif session is None:
      session = get_session()

    if session:
      return session.run(save_path, feed_dict=feed_dict)
    else:
      return save_path

  def restore(self, save_path, options=None):
    """Restore a training checkpoint.

    Restores `root_trackable` and any objects that it tracks
    (transitive). Either assigns values immediately if variables to restore have
    been created already, or defers restoration until the variables are
    created. Dependencies added to the `root_trackable` passed to the
    constructor after this call will be matched if they have a corresponding
    object in the checkpoint.

    When building a graph, restorations are added to the graph but not run.

    To disallow deferred loading, assert immediately that all checkpointed
    variables have been matched to variable objects:

    ```python
    saver = Saver(root)
    saver.restore(path).assert_consumed()
    ```

    An exception will be raised unless every object was matched and its
    variables already exist.

    When graph building, `assert_consumed()` indicates that all of the restore
    ops which will be created for this checkpoint have been created. They can be
    run via the `run_restore_ops()` function of the status object:

    ```python
    saver.restore(path).assert_consumed().run_restore_ops()
    ```

    If the checkpoint has not been consumed completely, then the list of restore
    ops will grow as more objects are added to the dependency graph.

    Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
    method. There is no deferred loading, and names are used to match
    variables. No restore ops are created/run until `run_restore_ops()` or
    `initialize_or_restore()` are called on the returned status object, even
    when executing eagerly. Re-encode name-based checkpoints using this
    object-based `Saver.save` as soon as possible.

    Attached dependencies that have no entry among the checkpoint root's
    children are skipped (there is nothing in the checkpoint to restore them
    from).

    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If None (as when there is no latest
        checkpoint for `tf.train.latest_checkpoint` to return), returns an
        object which may run initializers for objects in the dependency graph.
        If the checkpoint was written by the name-based
        `tf.compat.v1.train.Saver`, names are used to match variables.
      options: Optional `tf.train.CheckpointOptions` object.

    Returns:
      A load status object, which can be used to make assertions about the
      status of checkpoint restoration and run initialization/restore ops
      (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if
      `save_path` is `None`).

      If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus`
      object is returned which runs restore ops from a name-based saver.
    """
    options = options or checkpoint_options.CheckpointOptions()
    if save_path is None:
      return InitializationOnlyStatus(self._graph_view, ops.uid())
    reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    graph_building = not context.executing_eagerly()
    if graph_building:
      dtype_map = None
    else:
      dtype_map = reader.get_variable_to_dtype_map()
    try:
      object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError:
      # The object graph proto does not exist in this checkpoint. Try the
      # name-based compatibility mode.
      restore_coordinator = _NameBasedRestoreCoordinator(
          save_path=save_path,
          dtype_map=dtype_map)
      if not graph_building:
        for existing_trackable in self._graph_view.list_objects():
          # pylint: disable=protected-access
          existing_trackable._maybe_initialize_trackable()
          existing_trackable._name_based_restores.add(restore_coordinator)
          existing_trackable._name_based_attribute_restore(restore_coordinator)
          # pylint: enable=protected-access
      return NameBasedSaverStatus(
          restore_coordinator,
          graph_view=self._graph_view)

    if graph_building:
      if self._file_prefix_placeholder is None:
        with ops.device("/cpu:0"):
          self._file_prefix_placeholder = constant_op.constant("model")
      file_prefix_tensor = self._file_prefix_placeholder
      file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
    else:
      with ops.device("/cpu:0"):
        file_prefix_tensor = constant_op.constant(save_path)
      file_prefix_feed_dict = None
    object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
    object_graph_proto.ParseFromString(object_graph_string)
    checkpoint = _CheckpointRestoreCoordinator(
        object_graph_proto=object_graph_proto,
        save_path=save_path,
        save_path_tensor=file_prefix_tensor,
        restore_op_cache=self._restore_op_cache,
        graph_view=self._graph_view,
        options=options)
    base.CheckpointPosition(
        checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)

    # Attached dependencies are not attached to the root, so should be restored
    # separately.
    if self._graph_view.attached_dependencies:
      for ref in self._graph_view.attached_dependencies:
        if ref.name == "root":
          # Root dependency is automatically added to attached dependencies --
          # this can be ignored since it maps back to the root object.
          continue
        # Find proto ID of attached dependency (if it is in the proto).
        proto_id = next(
            (proto_ref.node_id
             for proto_ref in object_graph_proto.nodes[0].children
             if proto_ref.local_name == ref.name),
            None)

        if proto_id is None:
          # The attached dependency has no entry in the checkpoint, so there is
          # nothing to restore for it; a `CheckpointPosition` with
          # `proto_id=None` would not refer to any checkpoint node.
          continue

        if proto_id in checkpoint.object_by_proto_id:
          # Object has already been restored. This can happen when there's an
          # indirect connection from the attached object to the root.
          continue

        base.CheckpointPosition(
            checkpoint=checkpoint, proto_id=proto_id).restore(ref.ref)

    load_status = CheckpointLoadStatus(
        checkpoint,
        graph_view=self._graph_view,
        feed_dict=file_prefix_feed_dict)
    return load_status
1367
1368
def frozen_saver(root_trackable):
  """Creates a static saver for a snapshot of a trackable object graph.

  The returned `functional_saver.MultiDeviceSaver` writes object-based
  checkpoints. Those checkpoints do not reflect later structural changes to the
  object graph, only changes to the values of `Variable`s that were
  dependencies of the root object when `frozen_saver` was called.

  `restore` works on the returned saver, but requires that the object graph of
  the checkpoint being loaded exactly matches the object graph when
  `frozen_saver` was called. This is in contrast to the object-based restore
  performed by `tf.train.Checkpoint`, which attempts a fuzzy matching between a
  checkpoint's object graph and the current Python object graph.

  Args:
    root_trackable: A trackable object to save.

  Returns:
    A saver which saves object-based checkpoints for the object graph frozen at
    the time `frozen_saver` was called.
  """
  graph_view = graph_view_lib.ObjectGraphView(root_trackable)
  frozen_saveables = graph_view.frozen_saveable_objects()
  return functional_saver.MultiDeviceSaver(frozen_saveables)
1393
1394
def saver_with_op_caching(obj, attached_dependencies=None):
  """Builds a `TrackableSaver` for `obj` that caches SaveableObjects in graphs.

  The saver's graph view holds only a weak reference to `obj`. When executing
  eagerly no cache is used. Otherwise SaveableObjects are cached in an
  identity-keyed weak-key dictionary.

  Args:
    obj: The root trackable object to save.
    attached_dependencies: Optional attached dependencies forwarded to the
      `ObjectGraphView`.

  Returns:
    A `TrackableSaver`.
  """
  saveables_cache = (
      None if context.executing_eagerly()
      else object_identity.ObjectIdentityWeakKeyDictionary())
  view = graph_view_lib.ObjectGraphView(
      weakref.ref(obj),
      saveables_cache=saveables_cache,
      attached_dependencies=attached_dependencies)
  return TrackableSaver(view)
1405
1406
def _assert_trackable(obj):
  """Raises `ValueError` unless `obj` can be tracked by a `Checkpoint`.

  Accepted objects are instances of `base.Trackable` or
  `def_function.Function`.

  Args:
    obj: The candidate object.

  Raises:
    ValueError: If `obj` is neither of the accepted types.
  """
  accepted_types = (base.Trackable, def_function.Function)
  if isinstance(obj, accepted_types):
    return
  message_template = (
      "`Checkpoint` was expecting a trackable object (an object "
      "derived from `TrackableBase`), got {}. If you believe this "
      "object should be trackable (i.e. it is part of the "
      "TensorFlow Python API and manages state), please open an issue.")
  raise ValueError(message_template.format(obj))
1416
1417
# Mentions graph building / Sessions. The v2 version is below.
@tf_export(v1=["train.Checkpoint"])
class CheckpointV1(tracking.AutoTrackable):
  """Groups trackable objects, saving and restoring them.

  `Checkpoint`'s constructor accepts keyword arguments whose values are types
  that contain trackable state, such as `tf.compat.v1.train.Optimizer`
  implementations, `tf.Variable`, `tf.keras.Layer` implementations, or
  `tf.keras.Model` implementations. It saves these values with a checkpoint, and
  maintains a `save_counter` for numbering checkpoints.

  Example usage when graph building:

  ```python
  import tensorflow as tf
  import os

  checkpoint_directory = "/tmp/training_checkpoints"
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
  train_op = optimizer.minimize( ... )
  status.assert_consumed()  # Optional sanity checks.
  with tf.compat.v1.Session() as session:
    # Use the Session to restore variables, or initialize them if
    # tf.train.latest_checkpoint returned None.
    status.initialize_or_restore(session)
    for _ in range(num_training_steps):
      session.run(train_op)
    checkpoint.save(file_prefix=checkpoint_prefix)
  ```

  Example usage with eager execution enabled:

  ```python
  import tensorflow as tf
  import os

  tf.compat.v1.enable_eager_execution()

  checkpoint_directory = "/tmp/training_checkpoints"
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
  for _ in range(num_training_steps):
    optimizer.minimize( ... )  # Variables will be restored on creation.
  status.assert_consumed()  # Optional sanity checks.
  checkpoint.save(file_prefix=checkpoint_prefix)
  ```

  `Checkpoint.save` and `Checkpoint.restore` write and read object-based
  checkpoints, in contrast to `tf.compat.v1.train.Saver` which writes and reads
  `variable.name` based checkpoints. Object-based checkpointing saves a graph of
  dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s,
  etc.) with named edges, and this graph is used to match variables when
  restoring a checkpoint. It can be more robust to changes in the Python
  program, and helps to support restore-on-create for variables when executing
  eagerly. Prefer `tf.train.Checkpoint` over `tf.compat.v1.train.Saver` for new
  code.

  `Checkpoint` objects have dependencies on the objects passed as keyword
  arguments to their constructors, and each dependency is given a name that is
  identical to the name of the keyword argument for which it was created.
  TensorFlow classes like `Layer`s and `Optimizer`s will automatically add
  dependencies on their variables (e.g. "kernel" and "bias" for
  `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing
  dependencies easy in user-defined classes, since `Model` hooks into attribute
  assignment. For example:

  ```python
  class Regress(tf.keras.Model):

    def __init__(self):
      super(Regress, self).__init__()
      self.input_transform = tf.keras.layers.Dense(10)
      # ...

    def call(self, inputs):
      x = self.input_transform(inputs)
      # ...
  ```

  This `Model` has a dependency named "input_transform" on its `Dense` layer,
  which in turn depends on its variables. As a result, saving an instance of
  `Regress` using `tf.train.Checkpoint` will also save all the variables created
  by the `Dense` layer.

  When variables are assigned to multiple workers, each worker writes its own
  section of the checkpoint. These sections are then merged/re-indexed to behave
  as a single checkpoint. This avoids copying all variables to one worker, but
  does require that all workers see a common filesystem.

  While `tf.keras.Model.save_weights` and `tf.train.Checkpoint.save` save in the
  same format, note that the root of the resulting checkpoint is the object the
  save method is attached to. This means saving a `tf.keras.Model` using
  `save_weights` and loading into a `tf.train.Checkpoint` with a `Model`
  attached (or vice versa) will not match the `Model`'s variables. See the
  [guide to training
  checkpoints](https://www.tensorflow.org/guide/checkpoint) for
  details. Prefer `tf.train.Checkpoint` over `tf.keras.Model.save_weights` for
  training checkpoints.

  Attributes:
    save_counter: Incremented when `save()` is called. Used to number
      checkpoints.
  """

  def __init__(self, **kwargs):
    """Group objects into a training checkpoint.

    Args:
      **kwargs: Keyword arguments are set as attributes of this object, and are
        saved with the checkpoint. Values must be trackable objects.

    Raises:
      ValueError: If objects in `kwargs` are not trackable.
    """
    super(CheckpointV1, self).__init__()
    # Attributes are assigned in sorted key order so that dependencies are
    # added in the same order regardless of keyword argument order.
    for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
      setattr(self, k, v)
      # Validate the attribute as stored rather than `v` itself: attribute
      # assignment may convert containers (lists/dicts/tuples) into trackable
      # data structures. This is the same check `Checkpoint.__init__` uses.
      _assert_trackable(getattr(self, k))
    self._save_counter = None  # Created lazily for restore-on-create.
    self._save_assign_op = None
    self._saver = saver_with_op_caching(self)

  def _maybe_create_save_counter(self):
    """Create a save counter if it does not yet exist."""
    if self._save_counter is None:
      # Initialized to 0 and incremented before saving.
      with ops.device("/cpu:0"):
        # add_variable creates a dependency named "save_counter"; NoDependency
        # prevents creating a second dependency named "_save_counter".
        self._save_counter = data_structures.NoDependency(
            add_variable(
                self,
                name="save_counter",
                initializer=0,
                dtype=dtypes.int64,
                trainable=False))

  def write(self, file_prefix, session=None):
    """Writes a training checkpoint.

    The checkpoint includes variables created by this object and any
    trackable objects it depends on at the time `Checkpoint.write()` is
    called.

    `write` does not number checkpoints, increment `save_counter`, or update the
    metadata used by `tf.train.latest_checkpoint`. It is primarily intended for
    use by higher level checkpoint management utilities. `save` provides a very
    basic implementation of these features.

    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix).
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.

    Returns:
      The full path to the checkpoint (i.e. `file_prefix`). When building a
      function, the symbolic tensor returned by the saver is returned instead.
    """
    output = self._saver.save(file_prefix=file_prefix, session=session)
    if tensor_util.is_tf_type(output):
      if context.executing_eagerly():
        return compat.as_str(output.numpy())
      else:
        # Function building
        return output
    else:
      # Graph + Session, so we already session.ran it.
      return compat.as_str(output)

  @property
  def save_counter(self):
    """An integer variable which starts at zero and is incremented on save.

    Used to number checkpoints.

    Returns:
      The save counter variable.
    """
    self._maybe_create_save_counter()
    return self._save_counter

  def save(self, file_prefix, session=None):
    """Saves a training checkpoint and provides basic checkpoint management.

    The saved checkpoint includes variables created by this object and any
    trackable objects it depends on at the time `Checkpoint.save()` is
    called.

    `save` is a basic convenience wrapper around the `write` method,
    sequentially numbering checkpoints using `save_counter` and updating the
    metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
    management, for example garbage collection and custom numbering, may be
    provided by other utilities which also wrap `write`
    (`tf.train.CheckpointManager` for example).

    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix). Names are generated based on this
        prefix and `Checkpoint.save_counter`.
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.

    Returns:
      The full path to the checkpoint.

    Raises:
      NotImplementedError: If called while building a `tf.function`.
    """
    graph_building = not context.executing_eagerly()
    if graph_building:
      if ops.inside_function():
        raise NotImplementedError(
            "Calling tf.train.Checkpoint.save() from a function is not "
            "supported, as save() modifies saving metadata in ways not "
            "supported by TensorFlow Operations. Consider using "
            "tf.train.Checkpoint.write(), a lower-level API which does not "
            "update metadata. tf.train.latest_checkpoint and related APIs will "
            "not see this checkpoint.")
      if session is None:
        session = get_session()
      if self._save_counter is None:
        # When graph building, if this is a new save counter variable then it
        # needs to be initialized before assign_add. This is only an issue if
        # restore() has not been called first.
        session.run(self.save_counter.initializer)
    # In graph mode the increment op is built once and cached; eagerly it is
    # executed afresh on every call.
    if not graph_building or self._save_assign_op is None:
      with ops.colocate_with(self.save_counter):
        assign_op = self.save_counter.assign_add(1, read_value=True)
      if graph_building:
        self._save_assign_op = data_structures.NoDependency(assign_op)
    if graph_building:
      checkpoint_number = session.run(self._save_assign_op)
    else:
      checkpoint_number = assign_op.numpy()
    file_path = self.write(
        "%s-%d" % (file_prefix, checkpoint_number), session=session)
    checkpoint_management.update_checkpoint_state_internal(
        save_dir=os.path.dirname(file_prefix),
        model_checkpoint_path=file_path,
        all_model_checkpoint_paths=[file_path],
        save_relative_paths=True)
    return file_path

  def restore(self, save_path):
    """Restore a training checkpoint.

    Restores this `Checkpoint` and any objects it depends on.

    When executing eagerly, either assigns values immediately if variables to
    restore have been created already, or defers restoration until the variables
    are created. Dependencies added after this call will be matched if they have
    a corresponding object in the checkpoint (the restore request will queue in
    any trackable object waiting for the expected dependency to be added).

    When graph building, restoration ops are added to the graph but not run
    immediately.

    To ensure that loading is complete and no more assignments will take place,
    use the `assert_consumed()` method of the status object returned by
    `restore`:

    ```python
    checkpoint = tf.train.Checkpoint( ... )
    checkpoint.restore(path).assert_consumed()
    ```

    An exception will be raised if any Python objects in the dependency graph
    were not found in the checkpoint, or if any checkpointed values do not have
    a matching Python object.

    When graph building, `assert_consumed()` indicates that all of the restore
    ops that will be created for this checkpoint have been created. They can be
    run via the `run_restore_ops()` method of the status object:

    ```python
    checkpoint.restore(path).assert_consumed().run_restore_ops()
    ```

    If the checkpoint has not been consumed completely, then the list of restore
    ops will grow as more objects are added to the dependency graph.

    Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
    method. Names are used to match variables. No restore ops are created/run
    until `run_restore_ops()` or `initialize_or_restore()` are called on the
    returned status object when graph building, but there is restore-on-creation
    when executing eagerly. Re-encode name-based checkpoints using
    `tf.train.Checkpoint.save` as soon as possible.

    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If None (as when there is no latest
        checkpoint for `tf.train.latest_checkpoint` to return), returns an
        object which may run initializers for objects in the dependency graph.
        If the checkpoint was written by the name-based
        `tf.compat.v1.train.Saver`, names are used to match variables.

    Returns:
      A load status object, which can be used to make assertions about the
      status of a checkpoint restoration and run initialization/restore ops.

      The returned status object has the following methods:

      * `assert_consumed()`:
          Raises an exception if any variables are unmatched: either
          checkpointed values which don't have a matching Python object or
          Python objects in the dependency graph with no values in the
          checkpoint. This method returns the status object, and so may be
          chained with `initialize_or_restore` or `run_restore_ops`.

      * `assert_existing_objects_matched()`:
          Raises an exception if any existing Python objects in the dependency
          graph are unmatched. Unlike `assert_consumed`, this assertion will
          pass if values in the checkpoint have no corresponding Python
          objects. For example a `tf.keras.Layer` object which has not yet been
          built, and so has not created any variables, will pass this assertion
          but fail `assert_consumed`. Useful when loading part of a larger
          checkpoint into a new Python program, e.g. a training checkpoint with
          a `tf.compat.v1.train.Optimizer` was saved but only the state
          required for inference is being loaded. This method returns the
          status object, and so may be chained with `initialize_or_restore` or
          `run_restore_ops`.

      * `assert_nontrivial_match()`: Asserts that something aside from the root
          object was matched. This is a very weak assertion, but is useful for
          sanity checking in library code where objects may exist in the
          checkpoint which haven't been created in Python and some Python
          objects may not have a checkpointed value.

      * `expect_partial()`: Silence warnings about incomplete checkpoint
          restores. Warnings are otherwise printed for unused parts of the
          checkpoint file or object when the `Checkpoint` object is deleted
          (often at program shutdown).

      * `initialize_or_restore(session=None)`:
          When graph building, runs variable initializers if `save_path` is
          `None`, but otherwise runs restore operations. If no `session` is
          explicitly specified, the default session is used. No effect when
          executing eagerly (variables are initialized or restored eagerly).

      * `run_restore_ops(session=None)`:
          When graph building, runs restore operations. If no `session` is
          explicitly specified, the default session is used. No effect when
          executing eagerly (restore operations are run eagerly). May only be
          called when `save_path` is not `None`.
    """
    status = self._saver.restore(save_path=save_path)
    # Create the save counter now so it gets initialized with other variables
    # when graph building. Creating it earlier would lead to errors when using,
    # say, train.Saver() to save the model before initializing it.
    self._maybe_create_save_counter()
    if isinstance(status, NameBasedSaverStatus):
      # Register the counter as optionally restored for name-based checkpoints
      # (presumably because such checkpoints may not contain it — confirm).
      status.add_to_optionally_restored(self.save_counter)
    return status
1782
1783
1784@tf_export("train.Checkpoint", v1=[])
1785class Checkpoint(tracking.AutoTrackable):
1786  """Manages saving/restoring trackable values to disk.
1787
1788  TensorFlow objects may contain trackable state, such as `tf.Variable`s,
1789  `tf.keras.optimizers.Optimizer` implementations, `tf.data.Dataset` iterators,
1790  `tf.keras.Layer` implementations, or  `tf.keras.Model` implementations.
1791  These are called **trackable objects**.
1792
1793  A `Checkpoint` object can be constructed to save either a single or group of
1794  trackable objects to a checkpoint file. It maintains a `save_counter` for
1795  numbering checkpoints.
1796
1797  Example:
1798
1799  ```python
1800  model = tf.keras.Model(...)
1801  checkpoint = tf.train.Checkpoint(model)
1802
1803  # Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
1804  # checkpoint.save is called, the save counter is increased.
1805  save_path = checkpoint.save('/tmp/training_checkpoints')
1806
1807  # Restore the checkpointed values to the `model` object.
1808  checkpoint.restore(save_path)
1809  ```
1810
1811  Example 2:
1812
1813  ```python
1814  import tensorflow as tf
1815  import os
1816
1817  checkpoint_directory = "/tmp/training_checkpoints"
1818  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
1819
1820  # Create a Checkpoint that will manage two objects with trackable state,
1821  # one we name "optimizer" and the other we name "model".
1822  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
1823  status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
1824  for _ in range(num_training_steps):
1825    optimizer.minimize( ... )  # Variables will be restored on creation.
1826  status.assert_consumed()  # Optional sanity checks.
1827  checkpoint.save(file_prefix=checkpoint_prefix)
1828  ```
1829
1830  `Checkpoint.save()` and `Checkpoint.restore()` write and read object-based
1831  checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which
1832  writes and
1833  reads `variable.name` based checkpoints. Object-based checkpointing saves a
1834  graph of dependencies between Python objects (`Layer`s, `Optimizer`s,
1835  `Variable`s, etc.) with named edges, and this graph is used to match variables
1836  when restoring a checkpoint. It can be more robust to changes in the Python
1837  program, and helps to support restore-on-create for variables.
1838
1839  `Checkpoint` objects have dependencies on the objects passed as keyword
1840  arguments to their constructors, and each dependency is given a name that is
1841  identical to the name of the keyword argument for which it was created.
1842  TensorFlow classes like `Layer`s and `Optimizer`s will automatically add
1843  dependencies on their own variables (e.g. "kernel" and "bias" for
1844  `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing
1845  dependencies easy in user-defined classes, since `Model` hooks into attribute
1846  assignment. For example:
1847
1848  ```python
1849  class Regress(tf.keras.Model):
1850
1851    def __init__(self):
1852      super(Regress, self).__init__()
1853      self.input_transform = tf.keras.layers.Dense(10)
1854      # ...
1855
1856    def call(self, inputs):
1857      x = self.input_transform(inputs)
1858      # ...
1859  ```
1860
1861  This `Model` has a dependency named "input_transform" on its `Dense` layer,
1862  which in turn depends on its variables. As a result, saving an instance of
1863  `Regress` using `tf.train.Checkpoint` will also save all the variables created
1864  by the `Dense` layer.
1865
1866  When variables are assigned to multiple workers, each worker writes its own
1867  section of the checkpoint. These sections are then merged/re-indexed to behave
1868  as a single checkpoint. This avoids copying all variables to one worker, but
1869  does require that all workers see a common filesystem.
1870
1871  This function differs slightly from the Keras Model `save_weights` function.
1872  `tf.keras.Model.save_weights` creates a checkpoint file with the name
1873  specified in `filepath`, while `tf.train.Checkpoint` numbers the checkpoints,
1874  using `filepath` as the prefix for the checkpoint file names. Aside from this,
1875  `model.save_weights()` and `tf.train.Checkpoint(model).save()` are equivalent.
1876
1877  See the [guide to training
1878  checkpoints](https://www.tensorflow.org/guide/checkpoint) for
1879  details.
1880
1881  Attributes:
1882    save_counter: Incremented when `save()` is called. Used to number
1883      checkpoints.
1884  """
1885
1886  def __init__(self, root=None, **kwargs):
1887    """Creates a training checkpoint for a single or group of objects.
1888
1889    Args:
1890      root: The root object to checkpoint.
1891      **kwargs: Keyword arguments are set as attributes of this object, and are
1892        saved with the checkpoint. Values must be trackable objects.
1893
1894    Raises:
1895      ValueError: If `root` or the objects in `kwargs` are not trackable. A
1896        `ValueError` is also raised if the `root` object tracks different
1897        objects from the ones listed in attributes in kwargs (e.g.
1898        `root.child = A` and `tf.train.Checkpoint(root, child=B)` are
1899        incompatible).
1900
1901    """
1902    super(Checkpoint, self).__init__()
1903
1904    saver_root = self
1905    attached_dependencies = None
1906    self._save_counter = None  # Created lazily for restore-on-create.
1907    self._save_assign_op = None
1908
1909    if root:
1910      _assert_trackable(root)
1911      saver_root = root
1912      attached_dependencies = []
1913
1914      # All keyword arguments (including root itself) are set as children
1915      # of root.
1916      kwargs["root"] = root
1917      root._maybe_initialize_trackable()
1918
1919      self._save_counter = data_structures.NoDependency(
1920          root._lookup_dependency("save_counter"))
1921      self._root = data_structures.NoDependency(root)
1922
1923    for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
1924      setattr(self, k, v)
1925
1926      # Call getattr instead of directly using v because setattr converts
1927      # v to a Trackable data structure when v is a list/dict/tuple.
1928      converted_v = getattr(self, k)
1929      _assert_trackable(converted_v)
1930
1931      if root:
1932        # Make sure that root doesn't already have dependencies with these names
1933        child = root._lookup_dependency(k)
1934        if child is None:
1935          attached_dependencies.append(base.TrackableReference(k, converted_v))
1936        elif child != converted_v:
1937          raise ValueError(
1938              "Cannot create a Checkpoint with keyword argument {name} if "
1939              "root.{name} already exists.".format(name=k))
1940
1941    self._saver = saver_with_op_caching(saver_root, attached_dependencies)
1942    self._attached_dependencies = data_structures.NoDependency(
1943        attached_dependencies)
1944
1945  def _maybe_create_save_counter(self):
1946    """Create a save counter if it does not yet exist."""
1947    if self._save_counter is None:
1948      # Initialized to 0 and incremented before saving.
1949      with ops.device("/cpu:0"):
1950        # add_variable creates a dependency named "save_counter"; NoDependency
1951        # prevents creating a second dependency named "_save_counter".
1952        self._save_counter = data_structures.NoDependency(
1953            add_variable(
1954                self,
1955                name="save_counter",
1956                initializer=0,
1957                dtype=dtypes.int64,
1958                trainable=False))
1959        if self._attached_dependencies is not None:
1960          self._attached_dependencies.append(
1961              base.TrackableReference("save_counter", self._save_counter))
1962          # When loading a checkpoint, the save counter is created after
1963          # the checkpoint has been loaded, so it must be handled in a deferred
1964          # manner.
1965          restore = self.root._deferred_dependencies.pop("save_counter", ())  # pylint: disable=protected-access
1966          if restore:
1967            restore[0].restore(self._save_counter)
1968
1969  def write(self, file_prefix, options=None):
1970    """Writes a training checkpoint.
1971
1972    The checkpoint includes variables created by this object and any
1973    trackable objects it depends on at the time `Checkpoint.write()` is
1974    called.
1975
1976    `write` does not number checkpoints, increment `save_counter`, or update the
1977    metadata used by `tf.train.latest_checkpoint`. It is primarily intended for
1978    use by higher level checkpoint management utilities. `save` provides a very
1979    basic implementation of these features.
1980
1981    Checkpoints written with `write` must be read with `read`.
1982
1983    Example usage:
1984
1985    ```
1986    step = tf.Variable(0, name="step")
1987    checkpoint = tf.Checkpoint(step=step)
1988    checkpoint.write("/tmp/ckpt")
1989
1990    # Later, read the checkpoint with read()
1991    checkpoint.read("/tmp/ckpt").assert_consumed()
1992
1993    # You can also pass options to write() and read(). For example this
1994    # runs the IO ops on the localhost:
1995    options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
1996    checkpoint.write("/tmp/ckpt", options=options)
1997
1998    # Later, read the checkpoint with read()
1999    checkpoint.read("/tmp/ckpt", options=options).assert_consumed()
2000    ```
2001
2002    Args:
2003      file_prefix: A prefix to use for the checkpoint filenames
2004        (/path/to/directory/and_a_prefix).
2005      options: Optional `tf.train.CheckpointOptions` object.
2006
2007    Returns:
2008      The full path to the checkpoint (i.e. `file_prefix`).
2009    """
2010    options = options or checkpoint_options.CheckpointOptions()
2011    output = self._saver.save(file_prefix=file_prefix, options=options)
2012    if tensor_util.is_tf_type(output):
2013      if context.executing_eagerly():
2014        return compat.as_str(output.numpy())
2015      else:
2016        # Function building
2017        return output
2018    else:
2019      # Graph + Session, so we already session.ran it.
2020      return compat.as_str(output)
2021
2022  @property
2023  def save_counter(self):
2024    """An integer variable which starts at zero and is incremented on save.
2025
2026    Used to number checkpoints.
2027
2028    Returns:
2029      The save counter variable.
2030    """
2031    self._maybe_create_save_counter()
2032    return self._save_counter
2033
2034  def save(self, file_prefix, options=None):
2035    """Saves a training checkpoint and provides basic checkpoint management.
2036
2037    The saved checkpoint includes variables created by this object and any
2038    trackable objects it depends on at the time `Checkpoint.save()` is
2039    called.
2040
2041    `save` is a basic convenience wrapper around the `write` method,
2042    sequentially numbering checkpoints using `save_counter` and updating the
2043    metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
2044    management, for example garbage collection and custom numbering, may be
2045    provided by other utilities which also wrap `write` and `read`.
2046    (`tf.train.CheckpointManager` for example).
2047
2048    ```
2049    step = tf.Variable(0, name="step")
2050    checkpoint = tf.Checkpoint(step=step)
2051    checkpoint.save("/tmp/ckpt")
2052
2053    # Later, read the checkpoint with restore()
2054    checkpoint.restore("/tmp/ckpt").assert_consumed()
2055
2056    # You can also pass options to save() and restore(). For example this
2057    # runs the IO ops on the localhost:
2058    options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
2059    checkpoint.save("/tmp/ckpt", options=options)
2060
2061    # Later, read the checkpoint with restore()
2062    checkpoint.restore("/tmp/ckpt", options=options).assert_consumed()
2063    ```
2064
2065    Args:
2066      file_prefix: A prefix to use for the checkpoint filenames
2067        (/path/to/directory/and_a_prefix). Names are generated based on this
2068        prefix and `Checkpoint.save_counter`.
2069      options: Optional `tf.train.CheckpointOptions` object.
2070
2071    Returns:
2072      The full path to the checkpoint.
2073    """
2074    options = options or checkpoint_options.CheckpointOptions()
2075    graph_building = not context.executing_eagerly()
2076    if graph_building:
2077      if ops.inside_function():
2078        raise NotImplementedError(
2079            "Calling tf.train.Checkpoint.save() from a function is not "
2080            "supported, as save() modifies saving metadata in ways not "
2081            "supported by TensorFlow Operations. Consider using "
2082            "tf.train.Checkpoint.write(), a lower-level API which does not "
2083            "update metadata. tf.train.latest_checkpoint and related APIs will "
2084            "not see this checkpoint.")
2085      session = get_session()
2086      if self._save_counter is None:
2087        # When graph building, if this is a new save counter variable then it
2088        # needs to be initialized before assign_add. This is only an issue if
2089        # restore() has not been called first.
2090        session.run(self.save_counter.initializer)
2091    if not graph_building or self._save_assign_op is None:
2092      with ops.colocate_with(self.save_counter):
2093        assign_op = self.save_counter.assign_add(1, read_value=True)
2094      if graph_building:
2095        self._save_assign_op = data_structures.NoDependency(assign_op)
2096    if graph_building:
2097      checkpoint_number = session.run(self._save_assign_op)
2098    else:
2099      checkpoint_number = assign_op.numpy()
2100    file_path = self.write("%s-%d" % (file_prefix, checkpoint_number),
2101                           options=options)
2102    checkpoint_management.update_checkpoint_state_internal(
2103        save_dir=os.path.dirname(file_prefix),
2104        model_checkpoint_path=file_path,
2105        all_model_checkpoint_paths=[file_path],
2106        save_relative_paths=True)
2107    return file_path
2108
2109  def read(self, save_path, options=None):
2110    """Reads a training checkpoint written with `write`.
2111
2112    Reads this `Checkpoint` and any objects it depends on.
2113
2114    This method is just like `restore()` but does not expect the `save_counter`
2115    variable in the checkpoint. It only restores the objects that the checkpoint
2116    already depends on.
2117
2118    The method is primarily intended for use by higher level checkpoint
2119    management utilities that use `write()` instead of `save()` and have their
2120    own mechanisms to number and track checkpoints.
2121
2122    Example usage:
2123
2124    ```python
2125    # Create a checkpoint with write()
2126    ckpt = tf.train.Checkpoint(v=tf.Variable(1.))
2127    path = ckpt.write('/tmp/my_checkpoint')
2128
2129    # Later, load the checkpoint with read()
2130    # With restore() assert_consumed() would have failed.
2131    checkpoint.read(path).assert_consumed()
2132
2133    # You can also pass options to read(). For example this
2134    # runs the IO ops on the localhost:
2135    options = tf.train.CheckpointOptions(
2136        experimental_io_device="/job:localhost")
2137    checkpoint.read(path, options=options)
2138    ```
2139
2140    Args:
2141      save_path: The path to the checkpoint as returned by `write`.
2142      options: Optional `tf.train.CheckpointOptions` object.
2143
2144    Returns:
2145      A load status object, which can be used to make assertions about the
2146      status of a checkpoint restoration.  See `restore` for details.
2147    """
2148    options = options or checkpoint_options.CheckpointOptions()
2149    return self._saver.restore(save_path=save_path, options=options)
2150
2151  def restore(self, save_path, options=None):
2152    """Restores a training checkpoint.
2153
2154    Restores this `Checkpoint` and any objects it depends on.
2155
2156    This method is intended to be used to load checkpoints created by `save()`.
2157    For checkpoints created by `write()` use the `read()` method which does not
2158    expect the `save_counter` variable added by `save()`.
2159
2160    `restore()` either assigns values immediately if variables to restore have
2161    been created already, or defers restoration until the variables are
2162    created. Dependencies added after this call will be matched if they have a
2163    corresponding object in the checkpoint (the restore request will queue in
2164    any trackable object waiting for the expected dependency to be added).
2165
2166    To ensure that loading is complete and no more assignments will take place,
2167    use the `assert_consumed()` method of the status object returned by
2168    `restore()`:
2169
2170    ```python
2171    checkpoint = tf.train.Checkpoint( ... )
2172    checkpoint.restore(path).assert_consumed()
2173
2174    # You can additionally pass options to restore():
2175    options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
2176    checkpoint.restore(path, options=options).assert_consumed()
2177    ```
2178
2179    An exception will be raised if any Python objects in the dependency graph
2180    were not found in the checkpoint, or if any checkpointed values do not have
2181    a matching Python object.
2182
2183    Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can be
2184    loaded using this method. Names are used to match variables. Re-encode
2185    name-based checkpoints using `tf.train.Checkpoint.save` as soon as possible.
2186
2187    **Loading from SavedModel checkpoints**
2188
2189    To load values from a SavedModel, just pass the SavedModel directory
2190    to checkpoint.restore:
2191
2192    ```python
2193    model = tf.keras.Model(...)
2194    tf.saved_model.save(model, path)  # or model.save(path, save_format='tf')
2195
2196    checkpoint = tf.train.Checkpoint(model)
2197    checkpoint.restore(path).expect_partial()
2198    ```
2199
2200    This example calls `expect_partial()` on the loaded status, since
2201    SavedModels saved from Keras often generates extra keys in the checkpoint.
2202    Otherwise, the program prints a lot of warnings about unused keys at exit
2203    time.
2204
2205    Args:
2206      save_path: The path to the checkpoint, as returned by `save` or
2207        `tf.train.latest_checkpoint`. If the checkpoint was written by the
2208        name-based `tf.compat.v1.train.Saver`, names are used to match
2209        variables. This path may also be a SavedModel directory.
2210      options: Optional `tf.train.CheckpointOptions` object.
2211
2212    Returns:
2213      A load status object, which can be used to make assertions about the
2214      status of a checkpoint restoration.
2215
2216      The returned status object has the following methods:
2217
2218      * `assert_consumed()`:
2219          Raises an exception if any variables are unmatched: either
2220          checkpointed values which don't have a matching Python object or
2221          Python objects in the dependency graph with no values in the
2222          checkpoint. This method returns the status object, and so may be
2223          chained with other assertions.
2224
2225      * `assert_existing_objects_matched()`:
2226          Raises an exception if any existing Python objects in the dependency
2227          graph are unmatched. Unlike `assert_consumed`, this assertion will
2228          pass if values in the checkpoint have no corresponding Python
2229          objects. For example a `tf.keras.Layer` object which has not yet been
2230          built, and so has not created any variables, will pass this assertion
2231          but fail `assert_consumed`. Useful when loading part of a larger
2232          checkpoint into a new Python program, e.g. a training checkpoint with
2233          a `tf.compat.v1.train.Optimizer` was saved but only the state required
2234          for
2235          inference is being loaded. This method returns the status object, and
2236          so may be chained with other assertions.
2237
2238      * `assert_nontrivial_match()`: Asserts that something aside from the root
2239          object was matched. This is a very weak assertion, but is useful for
2240          sanity checking in library code where objects may exist in the
2241          checkpoint which haven't been created in Python and some Python
2242          objects may not have a checkpointed value.
2243
2244      * `expect_partial()`: Silence warnings about incomplete checkpoint
2245          restores. Warnings are otherwise printed for unused parts of the
2246          checkpoint file or object when the `Checkpoint` object is deleted
2247          (often at program shutdown).
2248
2249    Raises:
2250      NotFoundError: if the a checkpoint or SavedModel cannot be found at
2251        `save_path`.
2252    """
2253    orig_save_path = save_path
2254
2255    if save_path is not None and gfile.IsDirectory(save_path) and (
2256        (gfile.Exists(utils_impl.get_saved_model_pb_path(save_path)) or
2257         gfile.Exists(utils_impl.get_saved_model_pbtxt_path(save_path)))):
2258      save_path = utils_impl.get_variables_path(save_path)
2259
2260    try:
2261      status = self.read(save_path, options=options)
2262    except errors_impl.NotFoundError:
2263      raise errors_impl.NotFoundError(
2264          None, None,
2265          "Could not find checkpoint or SavedModel at {}."
2266          .format(orig_save_path))
2267    # Create the save counter now so it gets initialized with other variables
2268    # when graph building. Creating it earlier would lead to errors when using,
2269    # say, train.Saver() to save the model before initializing it.
2270    self._maybe_create_save_counter()
2271    if isinstance(status, NameBasedSaverStatus):
2272      status.add_to_optionally_restored(self.save_counter)
2273    return status
2274