1"""Utilities for saving/loading Trackable objects."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16
17import abc
18import collections
19import functools
20import glob
21import os
22import threading
23import time
24import weakref
25
26from tensorflow.core.protobuf import trackable_object_graph_pb2
27from tensorflow.python.checkpoint import checkpoint_management
28from tensorflow.python.checkpoint import checkpoint_options
29from tensorflow.python.checkpoint import functional_saver
30from tensorflow.python.checkpoint import graph_view as graph_view_lib
31from tensorflow.python.checkpoint import restore as restore_lib
32from tensorflow.python.checkpoint import save_util_v1
33from tensorflow.python.checkpoint import util
34from tensorflow.python.client import session as session_lib
35from tensorflow.python.eager import context
36from tensorflow.python.eager import def_function
37from tensorflow.python.eager import executor
38from tensorflow.python.framework import constant_op
39from tensorflow.python.framework import dtypes
40from tensorflow.python.framework import errors_impl
41from tensorflow.python.framework import ops
42from tensorflow.python.framework import tensor_shape
43from tensorflow.python.framework import tensor_util
44from tensorflow.python.lib.io import file_io
45from tensorflow.python.ops import array_ops
46from tensorflow.python.ops import gen_io_ops as io_ops
47from tensorflow.python.ops import init_ops
48from tensorflow.python.ops import variable_scope
49from tensorflow.python.ops import variables
50from tensorflow.python.platform import gfile
51from tensorflow.python.platform import tf_logging as logging
52from tensorflow.python.saved_model import utils_impl
53from tensorflow.python.saved_model.pywrap_saved_model import metrics
54from tensorflow.python.trackable import autotrackable
55from tensorflow.python.trackable import base
56from tensorflow.python.trackable import data_structures
57from tensorflow.python.training import py_checkpoint_reader
58from tensorflow.python.training import saver as v1_saver_lib
59from tensorflow.python.training.saving import saveable_object_util
60from tensorflow.python.util import compat
61from tensorflow.python.util import deprecation
62from tensorflow.python.util import object_identity
63from tensorflow.python.util import tf_contextlib
64from tensorflow.python.util import tf_inspect
65from tensorflow.python.util.tf_export import tf_export
66
# The callable that provides the Keras default session needed for saving.
_SESSION_PROVIDER = None

# Captures the timestamp of the first Checkpoint instantiation or of the end
# of the most recent write operation. Can be accessed by multiple Checkpoint
# instances.
_END_TIME_OF_LAST_WRITE = None
_END_TIME_OF_LAST_WRITE_LOCK = threading.Lock()

# API labels for cell names used in checkpoint metrics.
_CHECKPOINT_V1 = "checkpoint_v1"
_CHECKPOINT_V2 = "checkpoint_v2"

# Async thread used for asynchronous checkpoint.
_ASYNC_CHECKPOINT_THREAD = None


def _get_duration_microseconds(start_time_seconds, end_time_seconds):
  if end_time_seconds < start_time_seconds:
    # Avoid returning negative value in case of clock skew.
    return 0
  return round((end_time_seconds - start_time_seconds) * 1000000)
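
# Usage sketch (editor's note; illustrative, not part of the original module):
# this helper feeds the checkpoint metrics and clamps to zero so that clock
# skew never reports a negative duration. For example:
#
#   start = time.time()
#   # ... perform a checkpoint write ...
#   duration = _get_duration_microseconds(start, time.time())
#   # _get_duration_microseconds(10.0, 10.5) == 500000
#   # _get_duration_microseconds(10.5, 10.0) == 0 (skewed clock, clamped)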


@tf_export("__internal__.tracking.register_session_provider", v1=[])
def register_session_provider(session_provider):
  global _SESSION_PROVIDER
  # TODO(scottzhu): Change it back to only allow one time setting for session
  # provider once we finished the keras repo split.
  # if _SESSION_PROVIDER is None:
  _SESSION_PROVIDER = session_provider


def get_session():
  # Prefer TF's default session since get_session from Keras has side-effects.
  session = ops.get_default_session()
  if session is None:
    global _SESSION_PROVIDER
    if _SESSION_PROVIDER is not None:
      session = _SESSION_PROVIDER()  # pylint: disable=not-callable
  return session


def _get_checkpoint_size(prefix):
  """Calculates filesize of checkpoint based on prefix."""
  size = 0
  # Gather all files beginning with prefix (.index plus sharded data files).
  files = glob.glob("{}*".format(prefix))
  for file in files:
    # Use TensorFlow's C++ FileSystem API.
    size += metrics.CalculateFileSize(file)
  return size
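
# Editor's note on the expected layout: for a prefix such as
# "/tmp/ckpt/model-1", the glob above picks up the index file plus the sharded
# data files, e.g.
#
#   /tmp/ckpt/model-1.index
#   /tmp/ckpt/model-1.data-00000-of-00001
#
# so the returned size is the checkpoint's total on-disk footprint.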


class ObjectGraphProtoPrettyPrinter:
  """Lazily traverses an object graph proto to pretty print names.

  If no calls to `node_names` are made this object has no performance
  overhead. On the other hand, it will only traverse the object graph once, so
  repeated naming is cheap after the first.
  """

  __slots__ = ["_object_graph_proto", "_node_name_cache"]

  def __init__(self, object_graph_proto):
    self._object_graph_proto = object_graph_proto
    self._node_name_cache = None

  @property
  def node_names(self):
    """Lazily creates a mapping from node id to ("path", "to", "root")."""
    if self._node_name_cache is not None:
      return self._node_name_cache
    path_to_root = {}
    path_to_root[0] = ("(root)",)
    to_visit = collections.deque([0])
    while to_visit:
      node_id = to_visit.popleft()
      obj = self._object_graph_proto.nodes[node_id]
      for child in obj.children:
        if child.node_id not in path_to_root:
          path_to_root[child.node_id] = (
              path_to_root[node_id] + (child.local_name,))
          to_visit.append(child.node_id)

    node_names = {}
    for node_id, path in path_to_root.items():
      node_names[node_id] = ".".join(path)

    for node_id, node in enumerate(self._object_graph_proto.nodes):
      for slot_reference in node.slot_variables:
        node_names[slot_reference.slot_variable_node_id] = (
            f"{node_names[node_id]}'s state '{slot_reference.slot_name}' for "
            f"{node_names[slot_reference.original_variable_node_id]}")
    self._node_name_cache = node_names
    return node_names
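
  # Illustration (editor's addition): for a proto whose root (node 0) tracks a
  # child "model" (node 1) which tracks "dense" (node 2), `node_names` maps
  # roughly to
  #
  #   {0: "(root)", 1: "(root).model", 2: "(root).model.dense"}
  #
  # and a slot variable entry reads like
  #   "(root).optimizer's state 'm' for (root).model.dense.kernel".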


class _CheckpointRestoreCoordinatorDeleter:
  """Deleter to avoid overriding _CheckpointRestoreCoordinator.__del__()."""

  __slots__ = [
      "expect_partial", "object_graph_proto", "matched_proto_ids",
      "unused_attributes"
  ]

  def __init__(self, expect_partial, object_graph_proto, matched_proto_ids,
               unused_attributes):
    self.expect_partial = expect_partial
    self.object_graph_proto = object_graph_proto
    self.matched_proto_ids = matched_proto_ids
    self.unused_attributes = unused_attributes

  def set_expect_partial(self, expect_partial):
    self.expect_partial = expect_partial

  def __del__(self):
    if self.expect_partial:
      return
    if logging is None:
      # The logging module may have been unloaded when __del__ is called.
      log_fn = print
    else:
      log_fn = logging.warning
    unused_nodes_in_checkpoint = []
    unrestored_attributes_in_object = []
    pretty_printer = ObjectGraphProtoPrettyPrinter(self.object_graph_proto)
    for node_id, node in enumerate(self.object_graph_proto.nodes):
      if not node.attributes:
        continue
      if node_id not in self.matched_proto_ids:
        unused_nodes_in_checkpoint.append(pretty_printer.node_names[node_id])
    for node_id, attribute_name in self.unused_attributes.items():
      unrestored_attributes_in_object.append((
          pretty_printer.node_names[node_id], attribute_name))
    if unused_nodes_in_checkpoint or unrestored_attributes_in_object:
      # pylint:disable=line-too-long
      log_fn("Detected that an object, model, or tf.train.Checkpoint is being "
             "deleted with unrestored values. See the following logs for the "
             "specific values in question. To silence these warnings, use "
             "`status.expect_partial()`. See "
             "https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restore "
             "for details about the status object returned by the restore "
             "function.")
      # pylint:enable=line-too-long
      for node_path in unused_nodes_in_checkpoint:
        log_fn("Value in checkpoint could not be found in the restored object: "
               f"{node_path}")
      for node_path, attr in unrestored_attributes_in_object:
        log_fn("An attribute in the restored object could not be found in the "
               f"checkpoint. Object: {node_path}, attribute: {attr}")


class _CheckpointRestoreCoordinator:
  """Holds the status of an object-based checkpoint load."""

  def __init__(self, object_graph_proto, save_path, save_path_tensor, reader,
               restore_op_cache, graph_view, options, saveables_cache):
    """Specify the checkpoint being loaded.

    Args:
      object_graph_proto: The TrackableObjectGraph protocol buffer associated
        with this checkpoint.
      save_path: A string, the path to the checkpoint, as returned by
        `tf.train.latest_checkpoint`.
      save_path_tensor: A string `Tensor` which contains or will be fed the save
        path.
      reader: A `CheckpointReader` for `save_path`. If None,
        `_CheckpointRestoreCoordinator` will initialize one itself.
      restore_op_cache: A dictionary shared between
        `_CheckpointRestoreCoordinator`s for the same Python objects, used to
        look up restore ops by name to avoid re-creating them across multiple
        `restore()` calls.
      graph_view: A graph_view_lib.ObjectGraphView object for the restored
        objects.
      options: A CheckpointOptions object.
      saveables_cache: An optional cache of previously created SaveableObjects
        for each Trackable. Maps Trackables to a dictionary of attribute names
        to SaveableObjects.
245    """
246    self.options = options
247    self.object_graph_proto = object_graph_proto
248    self.restore_uid = ops.uid()
249    # Maps from proto ids to lists of attributes which were in the checkpoint
250    # but not loaded into any object, for error checking.
251    self.unused_attributes = {}
252    # Dictionary mapping from an id in the protocol buffer flat array to
253    # Trackable Python objects. This mapping may be deferred if a
254    # checkpoint is restored before all dependencies have been tracked. Uses
255    # weak references so that partial restorations don't create reference cycles
256    # (as objects with deferred dependencies will generally have references to
257    # this object).
258    self.object_by_proto_id = weakref.WeakValueDictionary()
259    self.matched_proto_ids = set()
260    # A set of all Python objects we've seen as dependencies, even if we didn't
261    # use them (for example because of inconsistent references when
262    # loading). Used to make status assertions fail when loading checkpoints
263    # that don't quite match.
264    self.all_python_objects = object_identity.ObjectIdentityWeakSet()
265    self.save_path_tensor = save_path_tensor
266    self.save_path_string = save_path
    # A NewCheckpointReader for the most recent checkpoint, for streaming Python
    # state restoration.
    self.reader = reader
    if self.reader is None:
      self.reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    self.dtype_map = self.reader.get_variable_to_dtype_map()
    self.shape_map = self.reader.get_variable_to_shape_map()
    # When graph building, contains a list of ops to run to restore objects from
    # this checkpoint.
    self.restore_ops = []
    self.restore_ops_by_name = restore_op_cache
    self.graph_view = graph_view
    self.new_restore_ops_callback = None
    # A mapping from optimizer proto ids to lists of slot variables to be
    # restored when the optimizer is tracked. Only includes slot variables whose
    # regular variables have already been created, and only for optimizer
    # objects which have not yet been created/tracked.
    self.deferred_slot_restorations = {}
    # A mapping from variable proto ids to lists of slot variables to be
    # restored when the variable is created/tracked. These get shifted over to
    # deferred_slot_restorations if the optimizer hasn't been created when that
    # happens.
    self.slot_restorations = {}
    # Controls whether errors are printed in __del__ if some objects did not
    # match.
    self.expect_partial_attr = False
    for node_index, node in enumerate(self.object_graph_proto.nodes):
      for slot_reference in node.slot_variables:
        # `node` refers to an `Optimizer`, since only these have slot variables.
        self.slot_restorations.setdefault(
            slot_reference.original_variable_node_id, []).append(
                base._SlotVariableRestoration(  # pylint: disable=protected-access
                    optimizer_id=node_index,
                    slot_variable_id=slot_reference.slot_variable_node_id,
                    slot_name=slot_reference.slot_name))

    self._deleter = _CheckpointRestoreCoordinatorDeleter(
        self.expect_partial_attr,
        self.object_graph_proto,
        self.matched_proto_ids,
        self.unused_attributes)

    self.saveables_cache = saveables_cache

  @property
  def expect_partial(self):
    return self.expect_partial_attr

  @expect_partial.setter
  def expect_partial(self, expect_partial):
    self.expect_partial_attr = expect_partial
    self._deleter.set_expect_partial(expect_partial)

  def new_restore_ops(self, new_ops):
    self.restore_ops.extend(new_ops)
    if self.new_restore_ops_callback:
      self.new_restore_ops_callback(new_ops)  # pylint: disable=not-callable

  def restore_saveables(self,
                        tensor_saveables,
                        python_positions,
                        registered_savers=None):
    """Run or build restore operations for SaveableObjects.

    Args:
      tensor_saveables: `SaveableObject`s which correspond to Tensors.
      python_positions: List of CheckpointPositions bound to `PythonState`
        objects which must be restored eagerly.
      registered_savers: A dict mapping saver names to a dict mapping object
        names to Trackables, i.e. `{saver_name: {object_name: trackable}}`.

    Returns:
      When graph building, a list of restore operations, either cached or newly
      created, to restore `tensor_saveables`.
    """
    restore_ops = []
    # Eagerly run restorations for Python state.
    for position in python_positions:
      key = position.object_proto.attributes[0].checkpoint_key
      position.trackable.deserialize(self.reader.get_tensor(key))

    # If we have new SaveableObjects, extract and cache restore ops.
    if tensor_saveables or registered_savers:
      flat_saveables = saveable_object_util.validate_and_slice_inputs(
          tensor_saveables)
      new_restore_ops = functional_saver.MultiDeviceSaver(
          flat_saveables,
          registered_savers).restore(self.save_path_tensor, self.options)
      if not context.executing_eagerly():
        for name, restore_op in sorted(new_restore_ops.items()):
          restore_ops.append(restore_op)
          assert name not in self.restore_ops_by_name
          self.restore_ops_by_name[name] = restore_op
    return restore_ops


class _NameBasedRestoreCoordinator:
  """Keeps the status of a name-based checkpoint restore."""

  def __init__(self, save_path, dtype_map=None):
    self.save_path = save_path
    self.dtype_map = dtype_map
    # A map from trackable objects to unused attribute names. We don't have
    # proto IDs when doing a name-based restore, so the map keys differ from
    # those in _CheckpointRestoreCoordinator.
    self.unused_attributes = object_identity.ObjectIdentityWeakKeyDictionary()
    self.restore_uid = ops.uid()

  def globally_named_object_attributes(self, trackable):
    """Create globally named SaveableObjects from attributes.

    If an object's attribute has no global name specified (default construction
    for the SaveableObject factory), records the failure in
    `self.unused_attributes` (which can then be used to make status assertions
    fail; see `NameBasedSaverStatus`).

    Args:
      trackable: An object to save.

    Yields:
      SaveableObjects for `trackable`'s attributes.
    """
    for attribute_name, saveable_factory in (
        trackable._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
      if callable(saveable_factory):
        try:
          saveable = saveable_factory()
        except TypeError:
          # This saveable object factory does not have a default name= argument,
          # which means there's no way to save/restore it using a name-based
          # checkpoint. Ignore the error now and make sure assert_consumed()
          # fails.
          self.unused_attributes.setdefault(trackable,
                                            []).append(attribute_name)
          continue
      else:
        saveable = saveable_factory
      names_to_saveables = saveable_object_util.op_list_to_dict(
          [saveable], convert_variable_to_tensor=False)
      for name, op in names_to_saveables.items():
        for saveable_object in saveable_object_util.saveable_objects_for_op(
            op=op, name=name):
          yield saveable_object

  def eager_restore(self, trackable):
    """Runs restore ops for `trackable`'s attributes."""
    # When graph building, we don't add any restore ops to the graph until
    # run_restore_ops/initialize_or_restore on the status object for name-based
    # checkpoints.
    assert context.executing_eagerly()
    for saveable in self.globally_named_object_attributes(trackable):
      restored_tensors = []
      tensor_missing = False
      for spec in saveable.specs:
        if spec.name in self.dtype_map:
          with ops.device("cpu:0"):
            restored, = io_ops.restore_v2(
                prefix=self.save_path,
                tensor_names=[spec.name],
                shape_and_slices=[""],
                dtypes=[self.dtype_map[spec.name]],
                name="%s_checkpoint_read" % (spec.name,))
          restored_tensors.append(array_ops.identity(restored))
        else:
          tensor_missing = True

      if tensor_missing:
        # Record that this variable didn't match so assertions will fail.
        self.unused_attributes.setdefault(trackable, []).append(saveable.name)
      else:
        # Ignores values missing from the checkpoint, as with object-based
        # restore. Status assertions can be used to check exact matches,
        # although it's unlikely to ever happen for name-based checkpoints.
        saveable.restore(
            restored_tensors=restored_tensors, restored_shapes=None)


# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange
# or consolidating the implementation with get_variable.
def _default_getter(name,
                    shape,
                    dtype,
                    initializer=None,
                    partition_info=None,
                    **kwargs):
  """A pared-down version of get_variable which does not reuse variables."""
  dtype = dtypes.as_dtype(dtype)
  shape_object = tensor_shape.as_shape(shape)
  with ops.init_scope():
    if initializer is None:
      initializer, initializing_from_value = (
          variable_scope._get_default_variable_store()._get_default_initializer(  # pylint: disable=protected-access
              name=name,
              shape=shape_object,
              dtype=dtype))
    else:
      initializing_from_value = not callable(initializer)
    # Same logic as get_variable
    variable_dtype = dtype.base_dtype
    if initializing_from_value:
      if shape is not None:
        raise ValueError("If initializer is a constant, do not specify shape.")
      initial_value = initializer
    else:
      # Instantiate initializer if provided initializer is a type object.
      if isinstance(initializer, type(init_ops.Initializer)):
        initializer = initializer(dtype=dtype)
      shape_list = None if shape is None else shape_object.as_list()
      if "partition_info" in tf_inspect.getargspec(initializer).args:
        initial_value = functools.partial(initializer,
                                          shape_list,
                                          dtype=dtype,
                                          partition_info=partition_info)
      else:
        initial_value = functools.partial(initializer,
                                          shape_list,
                                          dtype=dtype)

    return variables.VariableV1(
        initial_value=initial_value,
        name=name,
        dtype=variable_dtype,
        use_resource=True,
        **kwargs)


def add_variable(trackable,
                 name,
                 shape=None,
                 dtype=dtypes.float32,
                 initializer=None,
                 trainable=True):
  """Add a variable to a Trackable with no scope influence."""
  return trackable._add_variable_with_custom_getter(  # pylint: disable=protected-access
      name=name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      getter=_default_getter,
      trainable=trainable)
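
# Usage sketch (editor's addition; `Root` is a hypothetical class):
# `add_variable` attaches a variable as a checkpoint dependency without
# variable-scope name mangling, e.g.
#
#   class Root(autotrackable.AutoTrackable):
#
#     def __init__(self):
#       super().__init__()
#       self.v = add_variable(self, name="v", shape=[2])
#
# The dependency is recorded under the local name "v" regardless of any
# enclosing `tf.compat.v1.variable_scope`.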


def object_metadata(save_path):
  """Retrieves information about the objects in a checkpoint.

  Example usage:

  ```python
  object_graph = tf.contrib.checkpoint.object_metadata(
      tf.train.latest_checkpoint(checkpoint_directory))
  ckpt_variable_names = set()
  for node in object_graph.nodes:
    for attribute in node.attributes:
      ckpt_variable_names.add(attribute.full_name)
  ```

  Args:
    save_path: The path to the checkpoint, as returned by `save` or
      `tf.train.latest_checkpoint`.

  Returns:
    A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer.

  Raises:
    ValueError: If an object graph was not found in the checkpoint.
  """
  reader = py_checkpoint_reader.NewCheckpointReader(save_path)
  try:
    object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
  except errors_impl.NotFoundError:
    raise ValueError(
        f"The specified checkpoint \"{save_path}\" does not appear to be "
        "object-based (saved with TF2) since it is missing the key "
        f"\"{base.OBJECT_GRAPH_PROTO_KEY}\". Likely it was created with the "
        "TF1 name-based saver and does not contain an object dependency graph.")
  object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  object_graph_proto.ParseFromString(object_graph_string)
  return object_graph_proto


def list_objects(root_trackable):
  """Traverse the object graph and list all accessible objects.

  Looks for `Trackable` objects which are dependencies of
  `root_trackable`. Includes slot variables only if the variable they are
  slotting for and the optimizer are dependencies of `root_trackable`
  (i.e. if they would be saved with a checkpoint).

  Args:
    root_trackable: A `Trackable` object whose dependencies should be flattened.

  Returns:
    A flat list of objects.
  """
  return util.list_objects(graph_view_lib.ObjectGraphView(root_trackable))
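
# Usage sketch (editor's addition): the flattened list contains the root and
# its transitively tracked dependencies, e.g.
#
#   root = autotrackable.AutoTrackable()
#   root.leaf = autotrackable.AutoTrackable()
#   objects = list_objects(root)
#   assert root in objects and root.leaf in objects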


def gather_initializers(root_trackable):
  """Traverse the object graph and find initialization ops.

  Looks for `Trackable` objects which are dependencies of `root_trackable`
  and which have an `initializer` property. Includes initializers for slot
  variables only if the variable they are slotting for and the optimizer are
  dependencies of `root_trackable` (i.e. if they would be saved with a
  checkpoint).

  Args:
    root_trackable: A `Trackable` object to gather initializers for.

  Returns:
    A list of initialization ops.
  """
  trackable_objects = list_objects(root_trackable)
  return [
      c.initializer
      for c in trackable_objects
      if hasattr(c, "initializer") and c.initializer is not None
  ]
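
# Usage sketch (editor's addition): when graph building, these initializers
# are typically run once before training, e.g.
#
#   with ops.Graph().as_default(), session_lib.Session() as sess:
#     root = autotrackable.AutoTrackable()
#     root.v = variables.VariableV1(1.0, use_resource=True)
#     sess.run(gather_initializers(root))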


@tf_contextlib.contextmanager
def capture_dependencies(template):
  """Capture variables created within this scope as `Template` dependencies.

  Requires that `template.variable_scope` is active.

  This scope is intended as a compatibility measure, allowing a trackable
  object to add dependencies on variables created in a block of code which is
  not aware of object-based saving (and instead uses variable names
  heavily). This is how `Template` objects add dependencies on variables and
  sub-`Template`s. Where possible, use `tf.compat.v1.make_template` directly.

  Args:
    template: The `Template` object to register dependencies with.

  Yields:
    None (when used as a context manager).
  """
  name_prefix = template.variable_scope.name

  def _trackable_custom_creator(next_creator,
                                name,
                                initial_value,
                                trackable_parent=None,
                                **kwargs):
    """A variable creation hook which adds Trackable dependencies.

    Set, for example, during a `Template`'s first wrapped function
    execution. Ensures that (a) `template` depends on any trackable objects
    which open their own `capture_dependencies` scope inside this one and
    create variables, and (b) any variables not in a more deeply nested scope
    are added as dependencies directly.

    The `trackable_parent` argument is passed between custom creators but
    ignored when the variable object itself is created. This argument indicates
    (if not `None`) that a more deeply nested scope has already added the
    variable as a dependency, and that parent scopes should add a dependency on
    that object rather than on the variable directly.

    Args:
      next_creator: See `variable_scope.variable_creator_scope`; the next
        creator in the chain.
      name: The (full, scope-influenced) name of the variable. The `name_prefix`
        itself is stripped for the purposes of object-based dependency tracking,
        but scopes opened within this scope are respected.
      initial_value: See `variable_scope.variable_creator_scope`. Taken
        explicitly so the argument can be re-named and used with
        `Trackable._add_variable_with_custom_getter`.
      trackable_parent: If not None, a more deeply nested trackable object and
        its name prefix which were passed to `capture_dependencies` to add a
        dependency on (rather than depending on the variable directly).
      **kwargs: Passed through to the next creator.

    Returns:
      The output of `next_creator`: the fetched/created variable object.
    """

    def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
      # Ignored; this is the scope-stripped name, which we don't want to
      # propagate.
      inner_kwargs.pop("name")
      return next_creator(initial_value=initializer, name=name, **inner_kwargs)

    if name is not None and name.startswith(name_prefix):
      scope_stripped_name = name[len(name_prefix) + 1:]
      if not trackable_parent:
        return template._add_variable_with_custom_getter(  # pylint: disable=protected-access
            initializer=initial_value,
            name=scope_stripped_name,
            getter=_call_next_creator_renaming_initializer,
            # Disable error checking for Trackable. Exceptions are instead
            # raised if necessary when the object-based saver tries to
            # save/restore the object.
            overwrite=True,
            trackable_parent=(template, name_prefix),
            **kwargs)
      else:
        parent_object, parent_name_prefix = trackable_parent
        template._track_trackable(  # pylint: disable=protected-access
            parent_object,
            name=parent_name_prefix[len(name_prefix) + 1:],
            overwrite=True)
    return next_creator(
        name=name,
        initial_value=initial_value,
        trackable_parent=(template, name_prefix),
        **kwargs)

  with variable_scope.variable_creator_scope(_trackable_custom_creator):
    yield
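
# Usage sketch (editor's addition): `tf.compat.v1.make_template` sets this up
# internally; roughly, with the template's variable scope re-entered,
#
#   with variable_scope.variable_scope(template.variable_scope):
#     with capture_dependencies(template):
#       v = _default_getter("v", shape=[], dtype=dtypes.float32)
#   # `template` now depends on `v` under the scope-stripped name "v".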


class _LoadStatus:
  """Abstract base for load status callbacks."""

  @abc.abstractmethod
  def assert_consumed(self):
    """Raises an exception unless a non-trivial restoration has completed."""
    pass

  @abc.abstractmethod
  def assert_existing_objects_matched(self):
    """Raises an exception unless existing Python objects have been matched."""
    pass

  @abc.abstractmethod
  def assert_nontrivial_match(self):
    """Raises an exception if only the root object matched."""
    pass

  @abc.abstractmethod
  def run_restore_ops(self, session=None):
    """Runs restore ops from the checkpoint. Requires a valid checkpoint."""
    pass

  @abc.abstractmethod
  def initialize_or_restore(self, session=None):
    """Runs restore ops from the checkpoint, or initializes variables."""
    pass

  def expect_partial(self):
    """Silence warnings about incomplete checkpoint restores."""
    return self


@tf_export("__internal__.tracking.streaming_restore", v1=[])
def streaming_restore(status, session=None):
  """When graph building, runs restore ops as soon as they come in.

  Args:
    status: A _LoadStatus object from an object-based saver's restore().
      Streaming restore from name-based checkpoints is not currently supported.
    session: A session to run new restore ops in.
  """
  if context.executing_eagerly():
    # Streaming restore is the default/only behavior when executing eagerly.
    return
  if session is None:
    session = get_session()
  if isinstance(status, NameBasedSaverStatus):
    raise NotImplementedError(
        "Streaming restore not supported from name-based checkpoints when "
        "graph building. File a feature request if this limitation bothers "
        "you. As a workaround, consider either using tf.train.Checkpoint to "
        "load name-based checkpoints or enabling eager execution.")
  status.run_restore_ops(session=session)
  # pylint: disable=protected-access
  status._checkpoint.new_restore_ops_callback = (
      lambda ops: session.run(ops, feed_dict=status._feed_dict))
  # pylint: enable=protected-access
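
# Usage sketch (editor's addition): when graph building, this lets values flow
# into objects as they are created rather than after run_restore_ops():
#
#   status = saver.restore(save_path)  # `saver` is a TrackableSaver
#   streaming_restore(status, session=session)
#   # Restore ops created for objects built after this point are run
#   # immediately in `session`.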


def _objects_with_attributes(full_list):
  """Filters out objects with no direct variable dependencies for assertions."""
  return [o for o in full_list if o._gather_saveables_for_checkpoint()]  # pylint: disable=protected-access


class CheckpointLoadStatus(_LoadStatus):
  """Checks the status of checkpoint loading and manages restore ops.

  Returned from `Saver.restore`. Since `restore` may defer the loading of values
  in the checkpoint which don't yet have corresponding Python objects,
  `CheckpointLoadStatus` provides a callback to verify that checkpoint loading
  is complete (`assert_consumed`).

  When graph building, `restore` does not run restore ops itself since their
  creation may be deferred. The `run_restore_ops` method must be called once all
  Python objects with values to restore have been created and added to the
  dependency graph (this does not necessarily have to be the whole checkpoint;
  calling `run_restore_ops` while `assert_consumed` fails is supported and will
  partially restore the checkpoint).

  See `Saver.restore` for usage examples.
  """

  def __init__(self, checkpoint, feed_dict, graph_view):
    self._checkpoint = checkpoint
    self._feed_dict = feed_dict
    self._object_graph_view = graph_view
    # Keep a reference to the root, since object_graph_view might only have a
    # weakref.
    self._root = graph_view.root

  def assert_consumed(self):
    """Asserts that all objects in the checkpoint have been created/matched.

    Returns:
      `self` for chaining.

    Raises:
      AssertionError: If there are any Python objects in the dependency graph
        which have not been restored from this checkpoint or a later `restore`,
        or if there are any checkpointed values which have not been matched to
        Python objects.
    """
    pretty_printer = ObjectGraphProtoPrettyPrinter(
        self._checkpoint.object_graph_proto)
    self.assert_existing_objects_matched()
    for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
      if not node.attributes:
        # Only raise exceptions for nodes which themselves have attributes;
        # other nodes are either ultimately unimportant or have a child with
        # an attribute that will be checked instead.
        continue
      trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
      if trackable is None:
        raise AssertionError(
            "Unresolved object in checkpoint "
            f"{pretty_printer.node_names[node_id]}: {node}")
    if self._checkpoint.slot_restorations:
      # Sanity check; this collection should be clear if everything has been
      # restored.
      raise AssertionError(
          f"Unresolved slot restorations: {self._checkpoint.slot_restorations}")
    if self._checkpoint.unused_attributes:
      unused_attribute_messages = []
      for node_id, attribute in self._checkpoint.unused_attributes.items():
        obj = self._checkpoint.object_by_proto_id[node_id]
        unused_attribute_messages.append(
            f"{pretty_printer.node_names[node_id]} ({obj}): {attribute}")
      joined_attribute_messages = "\n".join(unused_attribute_messages)
      raise AssertionError(
          "Unused attributes in these objects (the attributes exist in the "
          f"checkpoint but were not restored):\n{joined_attribute_messages}")
    return self

  def assert_existing_objects_matched(self):
    """Asserts that trackable Python objects have been matched.

    Note that this is a weaker assertion than `assert_consumed`. It will only
    fail for existing Python objects which are (transitive) dependencies of the
    root object and which do not have an entry in the checkpoint.

    It will not fail, for example, if a `tf.keras.Layer` object has not yet been
    built and so has not created any `tf.Variable` objects.

    Returns:
      `self` for chaining.

    Raises:
      AssertionError: If a Python object exists in the transitive dependencies
        of the root object but does not have a value in the checkpoint.
    """
    for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
      trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
      if (trackable is not None and
          trackable._update_uid < self._checkpoint.restore_uid):  # pylint: disable=protected-access
        raise AssertionError(
            f"Object {node} not assigned a value from checkpoint.")
    for trackable_object in util.list_objects(self._object_graph_view):
      # Remove data structures that do not contain any variables from
      # restoration checks.
      if (isinstance(trackable_object,
                     data_structures.TrackableDataStructure) and
          not trackable_object._trackable_children()):  # pylint: disable=protected-access
        continue
      self._checkpoint.all_python_objects.add(trackable_object)
    unused_python_objects = (
        object_identity.ObjectIdentitySet(
            _objects_with_attributes(
                self._checkpoint.all_python_objects)) -
        object_identity.ObjectIdentitySet(
            self._checkpoint.object_by_proto_id.values()))
    if unused_python_objects:
      num_unused_python_objects = len(list(unused_python_objects))
      # Show at most 10 of the unmatched objects in the error message.
      num_variables_to_show = min(10, num_unused_python_objects)
      raise AssertionError(
          f"Found {num_unused_python_objects} Python objects that were "
          "not bound to checkpointed values, likely due to changes in the "
          f"Python program. Showing {num_variables_to_show} of "
          f"{num_unused_python_objects} unmatched objects: "
          f"{list(unused_python_objects)[:num_variables_to_show]}")
    return self

  def assert_nontrivial_match(self):
    """Raises an exception if only the root object matched."""
    for trackable_object in util.list_objects(self._object_graph_view):
      self._checkpoint.all_python_objects.add(trackable_object)
    if len(self._checkpoint.object_by_proto_id) <= 1:
      unused_python_objects = (
          object_identity.ObjectIdentitySet(
              _objects_with_attributes(self._checkpoint.all_python_objects)) -
          object_identity.ObjectIdentitySet(
              self._checkpoint.object_by_proto_id.values()))
      if unused_python_objects:
        raise AssertionError(
            "Nothing except the root object matched a checkpointed value. "
            "Typically this means that the checkpoint does not match the "
            "Python program. The following objects have no matching "
            f"checkpointed value: {list(unused_python_objects)}")
      else:
        raise AssertionError(
            "Nothing to load. No dependencies have been added to "
            f"{self._object_graph_view.root} yet.")
    return self

  def run_restore_ops(self, session=None):
    """Run operations to restore objects in the dependency graph."""
    if context.executing_eagerly():
      return  # Run eagerly
    if session is None:
      session = get_session()
    session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict)

  def initialize_or_restore(self, session=None):
    """Run operations to initialize or restore objects in the dependency graph.

    Any objects in the dependency graph which have initializers but are not in
    the checkpoint will have those initializers run, unless those variables are
    being restored by a later call to `tf.train.Checkpoint.restore()`.

    This method has a sibling in `InitializationOnlyStatus` which instead
    initializes variables. That type is returned if no checkpoint is specified
    in `Saver.restore`.

    Args:
      session: The session to run init/restore ops in. If `None`, uses the
        default session.
    """
    if context.executing_eagerly():
      return  # Initialization and restoration ops are run eagerly
    if session is None:
      session = get_session()
    all_objects = util.list_objects(self._object_graph_view)
    already_initialized_objects = object_identity.ObjectIdentitySet(
        self._checkpoint.object_by_proto_id.values())
    initializers_for_non_restored_variables = [
        c.initializer for c in all_objects
        if hasattr(c, "initializer")
        and c not in already_initialized_objects
        and (getattr(c, "_update_uid", self._checkpoint.restore_uid - 1)
             < self._checkpoint.restore_uid)
    ]
    self.run_restore_ops(session=session)
    session.run(initializers_for_non_restored_variables)

  def expect_partial(self):
    """Silence warnings about incomplete checkpoint restores."""
    self._checkpoint.expect_partial = True
    return self
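
# Usage sketch (editor's addition): a typical graph-building restore flow,
# assuming `saver` is a `TrackableSaver` over the root object:
#
#   status = saver.restore(save_path)
#   # ... construct every object that has a value in the checkpoint ...
#   status.assert_consumed().run_restore_ops(session)
#
# When deliberately loading a subset, call `status.expect_partial()` instead
# of the assertions to silence the __del__-time warnings.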


class InitializationOnlyStatus(_LoadStatus):
  """Returned from `Saver.restore` when no checkpoint has been specified.

  Objects of this type have the same `assert_consumed` method as
  `CheckpointLoadStatus`, but it always fails. However,
  `initialize_or_restore` works on objects of both types, and will
  initialize variables in `InitializationOnlyStatus` objects or restore them
  otherwise.
  """

  def __init__(self, object_graph_view, restore_uid):
    self._restore_uid = restore_uid
    self._object_graph_view = object_graph_view
    # Keep a reference to the root, since graph_view might only have a weakref.
    self._root = object_graph_view.root

  def assert_consumed(self):
    """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
    raise AssertionError(
        "No checkpoint specified (save_path=None); nothing is being restored.")

  def assert_existing_objects_matched(self):
    """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
    raise AssertionError(
        "No checkpoint specified (save_path=None); nothing is being restored.")

  def assert_nontrivial_match(self):
    """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
    raise AssertionError(
        "No checkpoint specified (save_path=None); nothing is being restored.")

  def run_restore_ops(self, session=None):
    """For consistency with `CheckpointLoadStatus`.

    Use `initialize_or_restore` for initializing if no checkpoint was passed
    to `Saver.restore` and restoring otherwise.

    Args:
      session: Not used.
    """
    raise AssertionError(
        "No checkpoint specified, so no restore ops are available "
        "(save_path=None to Saver.restore).")

  def initialize_or_restore(self, session=None):
    """Runs initialization ops for variables.

    Objects which would be saved by `Saver.save` will be initialized, unless
    those variables are being restored by a later call to
    `tf.train.Checkpoint.restore()`.

    This method does nothing when executing eagerly (initializers get run
    eagerly).

    Args:
      session: The session to run initialization ops in. If `None`, uses the
        default session.
    """
    if context.executing_eagerly():
      return  # run eagerly
    if session is None:
      session = get_session()
    trackable_objects = util.list_objects(self._object_graph_view)
    initializers = [
        c.initializer for c in trackable_objects
        if hasattr(c, "initializer") and c.initializer is not None
        and (getattr(c, "_update_uid", self._restore_uid - 1)
             < self._restore_uid)
    ]
    session.run(initializers)
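
# Usage sketch (editor's addition): because `Saver.restore(None)` returns this
# status, one call site can handle both fresh starts and resumption:
#
#   status = saver.restore(checkpoint_management.latest_checkpoint(directory))
#   status.initialize_or_restore(session)  # restores if a checkpoint existed,
#                                          # otherwise runs initializers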


_DEPRECATED_RESTORE_INSTRUCTIONS = (
    "Restoring a name-based tf.train.Saver checkpoint using the object-based "
    "restore API. This mode uses global names to match variables, and so is "
    "somewhat fragile. It also adds new restore ops to the graph each time it "
    "is called when graph building. Prefer re-encoding training checkpoints in "
    "the object-based format: run save() on the object-based saver (the same "
    "one this message is coming from) and use that checkpoint in the future.")


class NameBasedSaverStatus(_LoadStatus):
  """Status for loading a name-based training checkpoint."""

  # Ideally this deprecation decorator would be on the class, but that
  # interferes with isinstance checks.
  @deprecation.deprecated(
      date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
  def __init__(self, checkpoint, object_graph_view):
    self._checkpoint = checkpoint
    self._object_graph_view = object_graph_view
    self._optionally_restored = []
    # Keep a reference to the root, since graph_view might only have a weakref.
    self._root = object_graph_view.root

  def add_to_optionally_restored(self, var):
    """Add a variable to the list of optionally restored variables.

    There are situations where certain variables should be ignored in
    assertions such as assert_existing_objects_matched(). One example is a
    checkpoint saved with train.Saver() and restored with train.Checkpoint():
    the train.Saver() checkpoint may be missing the internal `save_counter`
    variable, which we want to ignore on restore.

    Args:
      var: The variable to treat as optionally restored.
    """
    self._optionally_restored.append(var)

  def assert_consumed(self):
    """Raises an exception if any variables are unmatched."""
    unused_attributes = list(self._checkpoint.unused_attributes.items())
    unused_attributes = [
        a for a in unused_attributes
        if all(a[0] is not x for x in self._optionally_restored)
    ]
    if unused_attributes:
      unused_attribute_strings = [
          f"\n    {obj}: {attributes}" for obj, attributes in unused_attributes]
      raise AssertionError(
          "Some objects had attributes which were not restored: "
          f"{unused_attribute_strings}")
    for trackable in util.list_objects(self._object_graph_view):
      # pylint: disable=protected-access
      trackable._maybe_initialize_trackable()
      if trackable._update_uid < self._checkpoint.restore_uid:
        raise AssertionError(f"Object not restored: {trackable}")
      # pylint: enable=protected-access
    return self

  def assert_existing_objects_matched(self):
    """Raises an exception if currently created objects are unmatched."""
    # For name-based checkpoints there's no object information in the
    # checkpoint, so there's no distinction between
    # assert_existing_objects_matched and assert_consumed (and both are less
    # useful since we don't touch Python objects or Python state).
    return self.assert_consumed()

  def assert_nontrivial_match(self):
    """Raises an exception if currently created objects are unmatched."""
    # For name-based checkpoints there's no object information in the
    # checkpoint, so there's no distinction between
    # assert_nontrivial_match and assert_consumed (and both are less
    # useful since we don't touch Python objects or Python state).
    return self.assert_consumed()

  def _gather_saveable_objects(self):
    """Walk the object graph, using global names for SaveableObjects."""
    objects = util.list_objects(self._object_graph_view)
    saveable_objects = []
    for trackable in objects:
      # pylint: disable=protected-access
      trackable._maybe_initialize_trackable()
      if trackable._update_uid < self._checkpoint.restore_uid:
        trackable._update_uid = self._checkpoint.restore_uid
      else:
        continue
      # pylint: enable=protected-access
      saveable_objects.extend(
          self._checkpoint.globally_named_object_attributes(trackable))
    return saveable_objects

  def run_restore_ops(self, session=None):
    """Load the name-based checkpoint using a new `tf.compat.v1.train.Saver`."""
    if context.executing_eagerly():
      return  # Nothing to do, variables are restored on creation.
    if session is None:
      session = get_session()
    with ops.device("/cpu:0"):
      saveables = self._gather_saveable_objects()
      v1_saver_lib.Saver(saveables).restore(
          sess=session, save_path=self._checkpoint.save_path)

  def initialize_or_restore(self, session=None):
    """Alias for `run_restore_ops`."""
    self.run_restore_ops(session=session)
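
# Usage sketch (editor's addition): restoring a `tf.compat.v1.train.Saver`
# checkpoint through the object-based API yields this status (with the
# deprecation warning above), since no object graph is stored:
#
#   status = saver.restore(name_based_checkpoint_path)
#   status.initialize_or_restore(session)  # alias for run_restore_ops
#
# Re-saving with the object-based saver migrates the checkpoint to the more
# robust object-graph format.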


class _SessionWithFeedDictAdditions(session_lib.SessionInterface):
  """Pretends to be a session, inserts extra feeds on run()."""

  def __init__(self, session, feed_additions):
    self._wrapped_session = session
    self._feed_additions = feed_additions

  def run(self, fetches, feed_dict=None, **kwargs):
    if feed_dict is None:
      feed_dict = {}
    else:
      feed_dict = feed_dict.copy()
    feed_dict.update(self._feed_additions)
    return self._wrapped_session.run(
        fetches=fetches, feed_dict=feed_dict, **kwargs)


class TrackableSaver:
  """Saves and restores a `Trackable` object and its dependencies.

  See `Trackable` for details of dependency management. `Saver` wraps
  `tf.compat.v1.train.Saver` for saving, including extra information about the
  graph of dependencies between Python objects. When restoring, it uses this
  information about the save-time dependency graph to more robustly match
  objects with their checkpointed values. When executing eagerly, it supports
  restoring variables on object creation (see `Saver.restore`).

  Values in a checkpoint are mapped to `Trackable` Python objects
  (`Variable`s, `Optimizer`s, `Layer`s) based on the names provided when the
  checkpoint was written. To avoid breaking existing checkpoints when modifying
  a class, dependency names (the names of attributes to which `Trackable`
  objects are assigned) may not change. These names are local to objects, in
  contrast to the `Variable.name`-based save/restore from
  `tf.compat.v1.train.Saver`, and so allow additional program transformations.
  """

  def __init__(self, graph_view):
    """Configure saving.

    Args:
      graph_view: An `ObjectGraphView` object containing a description of the
        object graph to save.
    """
    self._graph_view = graph_view

    # The following attributes are used when graph building.

    # Saveables caching: A dictionary mapping `Trackable` objects ->
    # attribute names -> SaveableObjects, used to avoid re-creating
    # SaveableObjects when graph building.
    if context.executing_eagerly():
      self._saveables_cache = None
    else:
      self._saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()

    # The file prefix placeholder is created lazily when graph building (and not
    # at all when executing eagerly) to avoid creating ops in the constructor
    # (when they may never be necessary).
    self._file_prefix_placeholder = None

    # Op caching for save
    self._object_graph_feed_tensor = None
    self._last_save_object_graph = None
    self._file_prefix_feed_tensor = None
    self._cached_save_operation = None

    # Op caching for restore, shared between _CheckpointRestoreCoordinators
    self._restore_op_cache = {}

    # Object map used for checkpoint. This attribute is to be overridden by a
    # Checkpoint subclass, e.g., AsyncCheckpoint, to replace the trackable
    # objects for checkpoint saving.
    self._object_map = None

  def _gather_saveables(self, object_graph_tensor=None):
    """Wraps _serialize_object_graph to include the object graph proto."""
    named_saveable_objects, graph_proto, feed_additions, registered_savers = (
        save_util_v1.serialize_gathered_objects(
            graph_view=self._graph_view,
            object_map=self._object_map,
            saveables_cache=self._saveables_cache))

    if object_graph_tensor is None:
      with ops.device("/cpu:0"):
        object_graph_tensor = constant_op.constant(
            graph_proto.SerializeToString(), dtype=dtypes.string)
    else:
      feed_additions.update(
          {object_graph_tensor: graph_proto.SerializeToString()})
    assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
    named_saveable_objects.append(
        base.NoRestoreSaveable(
            tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY))
    return (named_saveable_objects, graph_proto, feed_additions,
            registered_savers)

  def _save_cached_when_graph_building(self,
                                       file_prefix,
                                       object_graph_tensor,
                                       options,
                                       write_done_callback=None):
    """Create or retrieve save ops.

    Args:
      file_prefix: The prefix for saved checkpoint files.
      object_graph_tensor: A `Tensor` to which the current object graph will be
        fed.
      options: `CheckpointOptions` object.
      write_done_callback: Optional callback function to be executed once
        the underlying checkpoint saving is finished. Example usage includes
        updating the checkpoint internal state.

    Returns:
      A two-element tuple with a filename tensor and a feed_dict of tensors to
      feed when running it (if graph building). The feed dict contains the
      current object graph and any Python state to be saved in the
      checkpoint. When executing eagerly only the first element is meaningful.
    """
1224    (named_saveable_objects, graph_proto, feed_additions,
1225     registered_savers) = self._gather_saveables(
1226         object_graph_tensor=object_graph_tensor)
1227
1228    def _run_save():
1229      """Create and execute the SaveOp for the checkpoint."""
1230      if (self._last_save_object_graph != graph_proto
1231          # When executing eagerly, we need to re-create SaveableObjects each
1232          # time save() is called so they pick up new Tensors passed to their
1233          # constructors. That means the Saver needs to be copied with a new
1234          # var_list.
1235          or context.executing_eagerly() or ops.inside_function()):
1236        saver = functional_saver.MultiDeviceSaver(named_saveable_objects,
1237                                                  registered_savers)
1238        save_op = saver.save(file_prefix, options=options)
1239        with ops.device("/cpu:0"):
1240          with ops.control_dependencies([save_op]):
1241            self._cached_save_operation = array_ops.identity(file_prefix)
1242        self._last_save_object_graph = graph_proto
1243      return self._cached_save_operation, feed_additions
1244
1245    def _copy_tensors():
1246      """Copy the tensors to the host CPU device."""
      for saveable in named_saveable_objects:
        # Pin the device according to the SaveableObject's device location to
        # avoid unnecessary data copies when reading the variables. This is
        # aligned with the behavior in MultiDeviceSaver.save().
        original_device = saveable.device
        with ops.device(original_device):
          for spec in saveable.specs:
            tensor = spec.tensor
            device = spec.device
            if tensor is not None:
              with ops.device(saveable_object_util.set_cpu0(device)):
                spec._tensor = array_ops.identity(tensor)  # pylint: disable=protected-access
                # Modify the device info accordingly now that the tensors are
                # copied to the host CPU device.
                spec.device = saveable_object_util.set_cpu0(device)

    def _async_save_fn():
      """The thread function for executing async checkpoint save."""
      with context.executor_scope(
          executor.new_executor(
              enable_async=False, enable_streaming_enqueue=False)):
        _run_save()
        # Update the internal checkpoint state if the checkpoint event is
        # triggered from Checkpoint.save().
        if write_done_callback:
          write_done_callback(_convert_file_name_tensor_to_string(file_prefix))

    if options.experimental_enable_async_checkpoint:
      # Execute async-checkpoint.

      # Step-1: Explicitly copy the tensors to their host CPU device.
      _copy_tensors()

      # Step-2: Execute the rest of the checkpoint operations on the host device
      #         using an async executor.
      global _ASYNC_CHECKPOINT_THREAD
      if _ASYNC_CHECKPOINT_THREAD is not None:
        _ASYNC_CHECKPOINT_THREAD.join()
      _ASYNC_CHECKPOINT_THREAD = threading.Thread(target=_async_save_fn)
      _ASYNC_CHECKPOINT_THREAD.start()

      # Step-3: Return the expected checkpoint file path though the save op may
      #         not have finished.
      self._cached_save_operation = file_prefix
      return self._cached_save_operation, feed_additions

    # Execute the normal checkpoint, i.e., synchronous.
    return _run_save()

  def save(self, file_prefix, checkpoint_number=None, session=None,
           options=None, write_done_callback=None):
    """Save a training checkpoint.

    The saved checkpoint includes variables created by this object and any
    Trackable objects it depends on at the time `Saver.save()` is called.

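    A minimal usage sketch (most code goes through `tf.train.Checkpoint`
    rather than calling this saver directly; `root` stands for whatever
    trackable object the saver was constructed with):

    ```python
    saver = TrackableSaver(graph_view_lib.ObjectGraphView(root))
    path = saver.save("/tmp/ckpt", checkpoint_number=1)  # "/tmp/ckpt-1"
    ```
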
    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix). Names are generated based on this
        prefix and `checkpoint_number`, if provided.
      checkpoint_number: An integer variable or Tensor, used to number
        checkpoints. Typically this value is saved along with other variables in
        training checkpoints, which will happen automatically if it was created
        by `root_trackable` or one of its dependencies (via
        `Trackable._add_variable`).
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.
      options: Optional `tf.train.CheckpointOptions` object.
      write_done_callback: Optional callback function to be executed once
        the underlying checkpoint saving is finished. Example usage includes
        updating the checkpoint internal state.

    Returns:
      The full path to the checkpoint.
    """
    options = options or checkpoint_options.CheckpointOptions()
    feed_dict = {}
    use_session = (not context.executing_eagerly() and
                   not ops.inside_function())
    if checkpoint_number:
      file_prefix = "%s-%d" % (file_prefix, checkpoint_number)
    if use_session:
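      # Create reusable feed tensors once and feed the current file prefix
      # (and serialized object graph) on each call, so repeated saves do not
      # keep growing the graph.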
      if self._object_graph_feed_tensor is None:
        with ops.device("/cpu:0"):
          self._object_graph_feed_tensor = constant_op.constant(
              "", dtype=dtypes.string)
          self._file_prefix_feed_tensor = constant_op.constant(
              "", dtype=dtypes.string)
      object_graph_tensor = self._object_graph_feed_tensor
      file_prefix_tensor = self._file_prefix_feed_tensor
      feed_dict[file_prefix_tensor] = file_prefix
    else:
      with ops.device("/cpu:0"):
        file_prefix_tensor = ops.convert_to_tensor(
            file_prefix, dtype=dtypes.string)
      object_graph_tensor = None

    if not tensor_util.is_tensor(file_prefix):
      file_io.recursive_create_dir(os.path.dirname(file_prefix))

    save_path, new_feed_additions = self._save_cached_when_graph_building(
        file_prefix_tensor, object_graph_tensor, options, write_done_callback)

    if new_feed_additions:
      feed_dict.update(new_feed_additions)
    if not use_session:
      session = None
    elif session is None:
      session = get_session()

    if session:
      return session.run(save_path, feed_dict=feed_dict)
    else:
      return save_path

  def restore(self, save_path, options=None):
    """Restore a training checkpoint.

    Restores `root_trackable` and any objects that it tracks
    (transitive). Either assigns values immediately if variables to restore have
    been created already, or defers restoration until the variables are
    created. Dependencies added to the `root_trackable` passed to the
    constructor after this call will be matched if they have a corresponding
    object in the checkpoint.

    When building a graph, restorations are added to the graph but not run.

    ```python
    saver = Saver(root)
    saver.restore(path)
    ```

    To ensure that loading is complete and no more deferred restorations will
    take place, you can use the `assert_consumed()` method of the status object
    returned by the `restore` call.

    The assert will raise an exception unless every object was matched and all
    checkpointed values have a matching variable object.

    ```python
    saver = Saver(root)
    saver.restore(path).assert_consumed()
    ```

    When graph building, `assert_consumed()` indicates that all of the restore
    ops which will be created for this checkpoint have been created. They can be
    run via the `run_restore_ops()` function of the status object:

    ```python
    saver.restore(path).assert_consumed().run_restore_ops()
    ```

    If the checkpoint has not been consumed completely, then the list of restore
    ops will grow as more objects are added to the dependency graph.

    Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
    method. There is no deferred loading, and names are used to match
    variables. No restore ops are created/run until `run_restore_ops()` or
    `initialize_or_restore()` are called on the returned status object, even
    when executing eagerly. Re-encode name-based checkpoints using this
    object-based `Saver.save` as soon as possible.

    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If None (as when there is no latest
        checkpoint for `tf.train.latest_checkpoint` to return), returns an
        object which may run initializers for objects in the dependency graph.
        If the checkpoint was written by the name-based
        `tf.compat.v1.train.Saver`, names are used to match variables.
      options: Optional `tf.train.CheckpointOptions` object.

    Returns:
      A load status object, which can be used to make assertions about the
      status of checkpoint restoration and run initialization/restore ops
      (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if
      `save_path` is `None`).

      If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus`
      object is returned which runs restore ops from a name-based saver.

    Raises:
      RuntimeError: When a checkpoint file saved by async checkpoint is not
        available upon restore().
    """
    options = options or checkpoint_options.CheckpointOptions()
    if save_path is None:
      return InitializationOnlyStatus(self._graph_view, ops.uid())

    # Wait for any ongoing checkpoint write to finish.
    # TODO(chienchunh): Allow loading the file while other checkpoint events
    #                   are still ongoing. This needs a timeout mechanism,
    #                   along with condition variables to signal when the
    #                   checkpoint file is ready.
    global _ASYNC_CHECKPOINT_THREAD
    if _ASYNC_CHECKPOINT_THREAD is not None:
      _ASYNC_CHECKPOINT_THREAD.join()

    reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    graph_building = not context.executing_eagerly()
    if graph_building:
      dtype_map = None
    else:
      dtype_map = reader.get_variable_to_dtype_map()
    try:
      object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError:
      # The object graph proto does not exist in this checkpoint. Try the
      # name-based compatibility mode.
      restore_coordinator = _NameBasedRestoreCoordinator(
          save_path=save_path,
          dtype_map=dtype_map)
      if not graph_building:
        for existing_trackable in util.list_objects(self._graph_view):
          # pylint: disable=protected-access
          existing_trackable._maybe_initialize_trackable()
          existing_trackable._name_based_restores.add(restore_coordinator)
          existing_trackable._name_based_attribute_restore(restore_coordinator)
          # pylint: enable=protected-access
      return NameBasedSaverStatus(
          restore_coordinator,
          object_graph_view=self._graph_view)

    if graph_building:
      if self._file_prefix_placeholder is None:
        with ops.device("/cpu:0"):
          self._file_prefix_placeholder = constant_op.constant("model")
      file_prefix_tensor = self._file_prefix_placeholder
      file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
    else:
      with ops.device("/cpu:0"):
        file_prefix_tensor = constant_op.constant(save_path)
      file_prefix_feed_dict = None
    object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
    object_graph_proto.ParseFromString(object_graph_string)
    checkpoint = _CheckpointRestoreCoordinator(
        object_graph_proto=object_graph_proto,
        save_path=save_path,
        save_path_tensor=file_prefix_tensor,
        reader=reader,
        restore_op_cache=self._restore_op_cache,
        graph_view=self._graph_view,
        options=options,
        saveables_cache=self._saveables_cache)
    restore_lib.CheckpointPosition(
        checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)

    # Attached dependencies are not attached to the root, so should be restored
    # separately.
    if self._graph_view.attached_dependencies:
      for ref in self._graph_view.attached_dependencies:
        if ref.name == "root":
          # Root dependency is automatically added to attached dependencies --
          # this can be ignored since it maps back to the root object.
          continue
        proto_id = None
        # Find proto ID of attached dependency (if it is in the proto).
        for proto_ref in object_graph_proto.nodes[0].children:
          if proto_ref.local_name == ref.name:
            proto_id = proto_ref.node_id
            break

        if proto_id in checkpoint.object_by_proto_id:
          # Object has already been restored. This can happen when there's an
          # indirect connection from the attached object to the root.
          continue

        if proto_id is None:
          # Could not find attached dependency in proto.
          continue

        restore_lib.CheckpointPosition(
            checkpoint=checkpoint, proto_id=proto_id).restore(ref.ref)

    load_status = CheckpointLoadStatus(
        checkpoint,
        graph_view=self._graph_view,
        feed_dict=file_prefix_feed_dict)
    return load_status


def frozen_saver(root_trackable):
  """Creates a static `tf.compat.v1.train.Saver` from a trackable object.

  The returned `Saver` saves object-based checkpoints, but these checkpoints
  will no longer reflect structural changes to the object graph, only changes to
  the values of `Variable`s added as dependencies of the root object before
  `freeze` was called.

  `restore` works on the returned `Saver`, but requires that the object graph of
  the checkpoint being loaded exactly matches the object graph when `freeze` was
  called. This is in contrast to the object-based restore performed by
  `tf.train.Checkpoint` which attempts a fuzzy matching between a checkpoint's
  object graph and the current Python object graph.

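  For example (a sketch; the returned saver's `save` takes a file prefix
  tensor, as with `MultiDeviceSaver`):

  ```python
  root = tf.train.Checkpoint(v=tf.Variable(1.))
  saver = frozen_saver(root)
  saver.save(tf.constant("/tmp/frozen_ckpt"))
  ```
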
  Args:
    root_trackable: A trackable object to save.

  Returns:
    A saver which saves object-based checkpoints for the object graph frozen at
    the time `frozen_saver` was called.
  """
  named_saveable_objects, registered_savers = (
      save_util_v1.frozen_saveables_and_savers(
          graph_view_lib.ObjectGraphView(root_trackable)))
  return functional_saver.MultiDeviceSaver(named_saveable_objects,
                                           registered_savers)


def _assert_trackable(obj, name):
  if not isinstance(
      obj, (base.Trackable, def_function.Function)):
    raise ValueError(
        f"`Checkpoint` was expecting {name} to be a trackable object (an "
        f"object derived from `Trackable`), got {obj}. If you believe this "
        "object should be trackable (i.e. it is part of the "
        "TensorFlow Python API and manages state), please open an issue.")


def _update_checkpoint_state_internal(file_path):
  """Update internal checkpoint state."""
  checkpoint_management.update_checkpoint_state_internal(
      save_dir=os.path.dirname(file_path),
      model_checkpoint_path=file_path,
      all_model_checkpoint_paths=[file_path],
      save_relative_paths=True)


def _convert_file_name_tensor_to_string(tensor):
  """Convert file name tensor to string."""
  output = tensor
  if tensor_util.is_tf_type(output):
    # Convert to numpy if not `tf.function` building.
    if context.executing_eagerly():
      output = compat.as_str(output.numpy())
  else:
    # Graph + Session, so the value was already produced by session.run.
    output = compat.as_str(output)
  return output


# Mentions graph building / Sessions. The v2 version is below.
@tf_export(v1=["train.Checkpoint"])
class CheckpointV1(autotrackable.AutoTrackable):
  """Groups trackable objects, saving and restoring them.

  `Checkpoint`'s constructor accepts keyword arguments whose values are types
  that contain trackable state, such as `tf.compat.v1.train.Optimizer`
  implementations, `tf.Variable`, `tf.keras.Layer` implementations, or
  `tf.keras.Model` implementations. It saves these values with a checkpoint, and
  maintains a `save_counter` for numbering checkpoints.

  Example usage when graph building:

  ```python
  import tensorflow as tf
  import os

  checkpoint_directory = "/tmp/training_checkpoints"
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
  train_op = optimizer.minimize( ... )
  status.assert_consumed()  # Optional sanity checks.
  with tf.compat.v1.Session() as session:
    # Use the Session to restore variables, or initialize them if
    # tf.train.latest_checkpoint returned None.
    status.initialize_or_restore(session)
    for _ in range(num_training_steps):
      session.run(train_op)
    checkpoint.save(file_prefix=checkpoint_prefix)
  ```

  Example usage with eager execution enabled:

  ```python
  import tensorflow as tf
  import os

  tf.compat.v1.enable_eager_execution()

  checkpoint_directory = "/tmp/training_checkpoints"
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
  for _ in range(num_training_steps):
    optimizer.minimize( ... )  # Variables will be restored on creation.
  status.assert_consumed()  # Optional sanity checks.
  checkpoint.save(file_prefix=checkpoint_prefix)
  ```

  `Checkpoint.save` and `Checkpoint.restore` write and read object-based
  checkpoints, in contrast to `tf.compat.v1.train.Saver` which writes and reads
  `variable.name` based checkpoints. Object-based checkpointing saves a graph of
  dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s,
  etc.) with named edges, and this graph is used to match variables when
  restoring a checkpoint. It can be more robust to changes in the Python
  program, and helps to support restore-on-create for variables when executing
  eagerly. Prefer `tf.train.Checkpoint` over `tf.compat.v1.train.Saver` for new
  code.

  `Checkpoint` objects have dependencies on the objects passed as keyword
  arguments to their constructors, and each dependency is given a name that is
  identical to the name of the keyword argument for which it was created.
  TensorFlow classes like `Layer`s and `Optimizer`s will automatically add
  dependencies on their variables (e.g. "kernel" and "bias" for
  `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing
  dependencies easy in user-defined classes, since `Model` hooks into attribute
  assignment. For example:

  ```python
  class Regress(tf.keras.Model):

    def __init__(self):
      super().__init__()
      self.input_transform = tf.keras.layers.Dense(10)
      # ...

    def call(self, inputs):
      x = self.input_transform(inputs)
      # ...
  ```

  This `Model` has a dependency named "input_transform" on its `Dense` layer,
  which in turn depends on its variables. As a result, saving an instance of
  `Regress` using `tf.train.Checkpoint` will also save all the variables created
  by the `Dense` layer.

  When variables are assigned to multiple workers, each worker writes its own
  section of the checkpoint. These sections are then merged/re-indexed to behave
  as a single checkpoint. This avoids copying all variables to one worker, but
  does require that all workers see a common filesystem.

  While `tf.keras.Model.save_weights` and `tf.train.Checkpoint.save` save in the
  same format, note that the root of the resulting checkpoint is the object the
  save method is attached to. This means saving a `tf.keras.Model` using
  `save_weights` and loading into a `tf.train.Checkpoint` with a `Model`
  attached (or vice versa) will not match the `Model`'s variables. See the
  [guide to training
  checkpoints](https://www.tensorflow.org/guide/checkpoint) for
  details. Prefer `tf.train.Checkpoint` over `tf.keras.Model.save_weights` for
  training checkpoints.

  Attributes:
    save_counter: Incremented when `save()` is called. Used to number
      checkpoints.
  """

  def __init__(self, **kwargs):
    """Group objects into a training checkpoint.

    Args:
      **kwargs: Keyword arguments are set as attributes of this object, and are
        saved with the checkpoint. Values must be trackable objects.

    Raises:
      ValueError: If objects in `kwargs` are not trackable.
    """
    super().__init__()
    global _END_TIME_OF_LAST_WRITE
    with _END_TIME_OF_LAST_WRITE_LOCK:
      if _END_TIME_OF_LAST_WRITE is None:
        _END_TIME_OF_LAST_WRITE = time.time()

    for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
      setattr(self, k, v)
      if not isinstance(
          getattr(self, k), (base.Trackable, def_function.Function)):
        raise ValueError(
            "`Checkpoint` was expecting a trackable object (an object "
            f"derived from `Trackable`), got {v}. If you believe this "
            "object should be trackable (i.e. it is part of the "
            "TensorFlow Python API and manages state), please open an issue.")
    self._save_counter = None  # Created lazily for restore-on-create.
    self._save_assign_op = None
    self._saver = TrackableSaver(graph_view_lib.ObjectGraphView(self))

  def _maybe_create_save_counter(self):
    """Create a save counter if it does not yet exist."""
    if self._save_counter is None:
      # Initialized to 0 and incremented before saving.
      with ops.device("/cpu:0"):
        # add_variable creates a dependency named "save_counter"; NoDependency
        # prevents creating a second dependency named "_save_counter".
        self._save_counter = data_structures.NoDependency(
            add_variable(
                self,
                name="save_counter",
                initializer=0,
                dtype=dtypes.int64,
                trainable=False))

  def write(self, file_prefix, session=None):
    """Writes a training checkpoint.

    The checkpoint includes variables created by this object and any
    trackable objects it depends on at the time `Checkpoint.write()` is
    called.

    `write` does not number checkpoints, increment `save_counter`, or update the
    metadata used by `tf.train.latest_checkpoint`. It is primarily intended for
    use by higher level checkpoint management utilities. `save` provides a very
    basic implementation of these features.

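    For example (a sketch, assuming eager execution is enabled and `model` is
    a trackable object):

    ```python
    checkpoint = tf.compat.v1.train.Checkpoint(model=model)
    path = checkpoint.write("/tmp/ckpt")  # Returns "/tmp/ckpt".
    ```
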
    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix).
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.

    Returns:
      The full path to the checkpoint (i.e. `file_prefix`).
    """
    return self._write(file_prefix, session)

  def _write(self, file_prefix, session=None, write_done_callback=None):
    """Writes a training checkpoint.

    The checkpoint includes variables created by this object and any
    trackable objects it depends on at the time `Checkpoint.write()` is
    called.

    `write` does not number checkpoints, increment `save_counter`, or update the
    metadata used by `tf.train.latest_checkpoint`. It is primarily intended for
    use by higher level checkpoint management utilities. `save` provides a very
    basic implementation of these features.

    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix).
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.
      write_done_callback: Optional callback function to be executed once
        the underlying checkpoint saving is finished. Example usage includes
        updating the checkpoint internal state.

    Returns:
      The full path to the checkpoint (i.e. `file_prefix`).
    """
    start_time = time.time()
    output = self._saver.save(file_prefix=file_prefix, session=session)
    end_time = time.time()

    metrics.AddCheckpointWriteDuration(
        api_label=_CHECKPOINT_V1,
        microseconds=_get_duration_microseconds(start_time, end_time))

    global _END_TIME_OF_LAST_WRITE
    with _END_TIME_OF_LAST_WRITE_LOCK:
      metrics.AddTrainingTimeSaved(
          api_label=_CHECKPOINT_V1,
          microseconds=_get_duration_microseconds(_END_TIME_OF_LAST_WRITE,
                                                  end_time))
      _END_TIME_OF_LAST_WRITE = end_time

    if tensor_util.is_tf_type(output):
      # Convert to numpy if not `tf.function` building.
      if context.executing_eagerly():
        output = compat.as_str(output.numpy())
    else:
      # Graph + Session, so the value was already produced by session.run.
      output = compat.as_str(output)

    if write_done_callback:
      write_done_callback(output)

    metrics.RecordCheckpointSize(
        api_label=_CHECKPOINT_V1, filesize=_get_checkpoint_size(output))
    return output

  @property
  def save_counter(self):
    """An integer variable which starts at zero and is incremented on save.

    Used to number checkpoints.

    Returns:
      The save counter variable.
    """
    self._maybe_create_save_counter()
    return self._save_counter

  def save(self, file_prefix, session=None):
    """Saves a training checkpoint and provides basic checkpoint management.

    The saved checkpoint includes variables created by this object and any
    trackable objects it depends on at the time `Checkpoint.save()` is
    called.

    `save` is a basic convenience wrapper around the `write` method,
    sequentially numbering checkpoints using `save_counter` and updating the
    metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
    management, for example garbage collection and custom numbering, may be
    provided by other utilities which also wrap `write`
    (`tf.train.CheckpointManager` for example).

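    For example (a sketch, assuming eager execution is enabled; each call
    writes a new sequentially numbered checkpoint):

    ```python
    checkpoint = tf.train.Checkpoint(model=model)
    path = checkpoint.save("/tmp/ckpt")  # e.g. "/tmp/ckpt-1"
    ```
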
    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix). Names are generated based on this
        prefix and `Checkpoint.save_counter`.
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.

    Returns:
      The full path to the checkpoint.
    """
    graph_building = not context.executing_eagerly()
    if graph_building:
      if ops.inside_function():
        raise NotImplementedError(
            "Calling tf.train.Checkpoint.save() from a function is not "
            "supported, as save() modifies saving metadata in ways not "
            "supported by TensorFlow Operations. Consider using "
            "tf.train.Checkpoint.write(), a lower-level API which does not "
            "update metadata. tf.train.latest_checkpoint and related APIs will "
            "not see this checkpoint.")
      if session is None:
        session = get_session()
      if self._save_counter is None:
        # When graph building, if this is a new save counter variable then it
        # needs to be initialized before assign_add. This is only an issue if
        # restore() has not been called first.
        session.run(self.save_counter.initializer)
    if not graph_building or self._save_assign_op is None:
      with ops.colocate_with(self.save_counter):
        assign_op = self.save_counter.assign_add(1, read_value=True)
      if graph_building:
        self._save_assign_op = data_structures.NoDependency(assign_op)
    if graph_building:
      checkpoint_number = session.run(self._save_assign_op)
    else:
      checkpoint_number = assign_op.numpy()
    file_path = self.write(
        "%s-%d" % (file_prefix, checkpoint_number), session=session)
    checkpoint_management.update_checkpoint_state_internal(
        save_dir=os.path.dirname(file_prefix),
        model_checkpoint_path=file_path,
        all_model_checkpoint_paths=[file_path],
        save_relative_paths=True)
    return file_path

  def restore(self, save_path):
    """Restore a training checkpoint.

    Restores this `Checkpoint` and any objects it depends on.

    When executing eagerly, either assigns values immediately if variables to
    restore have been created already, or defers restoration until the variables
    are created. Dependencies added after this call will be matched if they have
    a corresponding object in the checkpoint (the restore request will queue in
    any trackable object waiting for the expected dependency to be added).

    When graph building, restoration ops are added to the graph but not run
    immediately.

    ```python
    checkpoint = tf.train.Checkpoint( ... )
    checkpoint.restore(path)
    ```

    To ensure that loading is complete and no more deferred restorations will
    take place, you can use the `assert_consumed()` method of the status object
    returned by `restore`.
    The assert will raise an exception if any Python objects in the dependency
    graph were not found in the checkpoint, or if any checkpointed values do not
    have a matching Python object:

    ```python
    checkpoint = tf.train.Checkpoint( ... )
    checkpoint.restore(path).assert_consumed()
    ```

    When graph building, `assert_consumed()` indicates that all of the restore
    ops that will be created for this checkpoint have been created. They can be
    run via the `run_restore_ops()` method of the status object:

    ```python
    checkpoint.restore(path).assert_consumed().run_restore_ops()
    ```

    If the checkpoint has not been consumed completely, then the list of restore
    ops will grow as more objects are added to the dependency graph.

    To check that all variables in the Python object have restored values from
    checkpoint, use `assert_existing_objects_matched()`. This assertion is
    useful when called after the variables in your graph have been created.

    Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
    method. Names are used to match variables. No restore ops are created/run
    until `run_restore_ops()` or `initialize_or_restore()` are called on the
    returned status object when graph building, but there is restore-on-creation
    when executing eagerly. Re-encode name-based checkpoints using
    `tf.train.Checkpoint.save` as soon as possible.

    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If None (as when there is no latest
        checkpoint for `tf.train.latest_checkpoint` to return), returns an
        object which may run initializers for objects in the dependency graph.
        If the checkpoint was written by the name-based
        `tf.compat.v1.train.Saver`, names are used to match variables.

    Returns:
      A load status object, which can be used to make assertions about the
      status of a checkpoint restoration and run initialization/restore ops.

      The returned status object has the following methods:

      * `assert_consumed()`:
          Raises an exception if any variables are unmatched: either
          checkpointed values which don't have a matching Python object or
          Python objects in the dependency graph with no values in the
          checkpoint. This method returns the status object, and so may be
          chained with `initialize_or_restore` or `run_restore_ops`.

      * `assert_existing_objects_matched()`:
          Raises an exception if any existing Python objects in the dependency
          graph are unmatched. Unlike `assert_consumed`, this assertion will
          pass if values in the checkpoint have no corresponding Python
          objects. For example a `tf.keras.Layer` object which has not yet been
          built, and so has not created any variables, will pass this assertion
          but will fail `assert_consumed`. Useful when loading part of a larger
          checkpoint into a new Python program, e.g. a training checkpoint with
          a `tf.compat.v1.train.Optimizer` was saved but only the state required
          for inference is being loaded. This method returns the status object,
          and so may be chained with `initialize_or_restore` or
          `run_restore_ops`.

      * `assert_nontrivial_match()`: Asserts that something aside from the root
          object was matched. This is a very weak assertion, but is useful for
          sanity checking in library code where objects may exist in the
          checkpoint which haven't been created in Python and some Python
          objects may not have a checkpointed value.

      * `expect_partial()`: Silence warnings about incomplete checkpoint
          restores. Warnings are otherwise printed for unused parts of the
          checkpoint file or object when the `Checkpoint` object is deleted
          (often at program shutdown).

      * `initialize_or_restore(session=None)`:
          When graph building, runs variable initializers if `save_path` is
          `None`, but otherwise runs restore operations. If no `session` is
          explicitly specified, the default session is used. No effect when
          executing eagerly (variables are initialized or restored eagerly).

      * `run_restore_ops(session=None)`:
          When graph building, runs restore operations. If no `session` is
          explicitly specified, the default session is used. No effect when
          executing eagerly (restore operations are run eagerly). May only be
          called when `save_path` is not `None`.
    """
    start_time = time.time()
    status = self._saver.restore(save_path=save_path)
    # Create the save counter now so it gets initialized with other variables
    # when graph building. Creating it earlier would lead to errors when using,
    # say, train.Saver() to save the model before initializing it.
    self._maybe_create_save_counter()
    if isinstance(status, NameBasedSaverStatus):
      status.add_to_optionally_restored(self.save_counter)

    metrics.AddCheckpointReadDuration(
        api_label=_CHECKPOINT_V1,
        microseconds=_get_duration_microseconds(start_time, time.time()))
    return status


@tf_export("train.Checkpoint", v1=[])
class Checkpoint(autotrackable.AutoTrackable):
  """Manages saving/restoring trackable values to disk.

  TensorFlow objects may contain trackable state, such as `tf.Variable`s,
  `tf.keras.optimizers.Optimizer` implementations, `tf.data.Dataset` iterators,
  `tf.keras.Layer` implementations, or `tf.keras.Model` implementations.
  These are called **trackable objects**.

  A `Checkpoint` object can be constructed to save either a single or group of
  trackable objects to a checkpoint file. It maintains a `save_counter` for
  numbering checkpoints.

  Example:

  ```python
  model = tf.keras.Model(...)
  checkpoint = tf.train.Checkpoint(model)

  # Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
  # checkpoint.save is called, the save counter is increased.
  save_path = checkpoint.save('/tmp/training_checkpoints')

  # Restore the checkpointed values to the `model` object.
  checkpoint.restore(save_path)
  ```

  Example 2:

  ```python
  import tensorflow as tf
  import os

  checkpoint_directory = "/tmp/training_checkpoints"
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

  # Create a Checkpoint that will manage two objects with trackable state,
  # one we name "optimizer" and the other we name "model".
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
  for _ in range(num_training_steps):
    optimizer.minimize( ... )  # Variables will be restored on creation.
  status.assert_consumed()  # Optional sanity checks.
  checkpoint.save(file_prefix=checkpoint_prefix)
  ```

  `Checkpoint.save()` and `Checkpoint.restore()` write and read object-based
  checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which
  writes and reads `variable.name` based checkpoints. Object-based
  checkpointing saves a graph of dependencies between Python objects (`Layer`s,
  `Optimizer`s, `Variable`s, etc.) with named edges, and this graph is used to
  match variables when restoring a checkpoint. It can be more robust to changes
  in the Python program, and helps to support restore-on-create for variables.

  `Checkpoint` objects have dependencies on the objects passed as keyword
  arguments to their constructors, and each dependency is given a name that is
  identical to the name of the keyword argument for which it was created.
  TensorFlow classes like `Layer`s and `Optimizer`s will automatically add
  dependencies on their own variables (e.g. "kernel" and "bias" for
  `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing
  dependencies easy in user-defined classes, since `Model` hooks into attribute
  assignment. For example:

  ```python
  class Regress(tf.keras.Model):

    def __init__(self):
      super().__init__()
      self.input_transform = tf.keras.layers.Dense(10)
      # ...

    def call(self, inputs):
      x = self.input_transform(inputs)
      # ...
  ```

  This `Model` has a dependency named "input_transform" on its `Dense` layer,
  which in turn depends on its variables. As a result, saving an instance of
  `Regress` using `tf.train.Checkpoint` will also save all the variables created
  by the `Dense` layer.

  When variables are assigned to multiple workers, each worker writes its own
  section of the checkpoint. These sections are then merged/re-indexed to behave
  as a single checkpoint. This avoids copying all variables to one worker, but
  does require that all workers see a common filesystem.

  This function differs slightly from the Keras Model `save_weights` function.
  `tf.keras.Model.save_weights` creates a checkpoint file with the name
  specified in `filepath`, while `tf.train.Checkpoint` numbers the checkpoints,
  using `filepath` as the prefix for the checkpoint file names. Aside from this,
  `model.save_weights()` and `tf.train.Checkpoint(model).save()` are equivalent.

  See the [guide to training
  checkpoints](https://www.tensorflow.org/guide/checkpoint) for
  details.

  Attributes:
    save_counter: Incremented when `save()` is called. Used to number
      checkpoints.
  """

  def __init__(self, root=None, **kwargs):
    """Creates a training checkpoint for a single or group of objects.

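    For example (a sketch of the `root` semantics; keyword arguments become
    children of `root` when one is given):

    ```python
    model = tf.keras.Model(...)
    opt = tf.keras.optimizers.Adam()
    ckpt = tf.train.Checkpoint(root=model, optimizer=opt)
    # Roughly equivalent to tracking `opt` as `model.optimizer`, provided
    # `model` does not already track a different "optimizer" dependency.
    ```
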
    Args:
      root: The root object to checkpoint. `root` may be a trackable object or
        `WeakRef` of a trackable object.
      **kwargs: Keyword arguments are set as attributes of this object, and are
        saved with the checkpoint. All `kwargs` must be trackable objects, or a
        nested structure of trackable objects (`list`, `dict`, or `tuple`).

    Raises:
      ValueError: If `root` or the objects in `kwargs` are not trackable. A
        `ValueError` is also raised if the `root` object tracks different
        objects from the ones listed in attributes in kwargs (e.g.
        `root.child = A` and `tf.train.Checkpoint(root, child=B)` are
        incompatible).
    """
    super().__init__()
    global _END_TIME_OF_LAST_WRITE
    with _END_TIME_OF_LAST_WRITE_LOCK:
      if _END_TIME_OF_LAST_WRITE is None:
        _END_TIME_OF_LAST_WRITE = time.time()

    attached_dependencies = None
    self._save_counter = None  # Created lazily for restore-on-create.
    self._save_assign_op = None

    if root:
      trackable_root = root() if isinstance(root, weakref.ref) else root
      _assert_trackable(trackable_root, "root")
      attached_dependencies = []

      # All keyword arguments (including root itself) are set as children
      # of root.
      kwargs["root"] = root
      trackable_root._maybe_initialize_trackable()

      self._save_counter = data_structures.NoDependency(
          trackable_root._lookup_dependency("save_counter"))

    for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
      setattr(self, k, v)

      # Call getattr instead of directly using v because setattr converts
      # v to a Trackable data structure when v is a list/dict/tuple.
      converted_v = getattr(self, k)
      if isinstance(converted_v, weakref.ref):
        converted_v = converted_v()
      _assert_trackable(converted_v, k)

      if root:
        # Make sure that root doesn't already have dependencies with these names
        child = trackable_root._lookup_dependency(k)
        if child is None:
          attached_dependencies.append(
              base.WeakTrackableReference(k, converted_v))
        elif child != converted_v:
          raise ValueError(
              f"Cannot create a Checkpoint with keyword argument {k} if "
              f"root.{k} already exists.")

    self._saver = TrackableSaver(
        graph_view_lib.ObjectGraphView(
            root if root else self,
            attached_dependencies=attached_dependencies))
    self._attached_dependencies = data_structures.NoDependency(
        attached_dependencies)

  def _maybe_create_save_counter(self):
    """Create a save counter if it does not yet exist."""
    if self._save_counter is None:
      # Initialized to 0 and incremented before saving.
      with ops.device("/cpu:0"):
        # add_variable creates a dependency named "save_counter"; NoDependency
        # prevents creating a second dependency named "_save_counter".
        self._save_counter = data_structures.NoDependency(
            add_variable(
                self,
                name="save_counter",
                initializer=0,
                dtype=dtypes.int64,
                trainable=False))
        if self._attached_dependencies is not None:
          self._attached_dependencies.append(
              # Store a strong reference to the `save_counter`, so that if the
              # `Checkpoint` object is deleted, the `save_counter` does not get
              # deleted immediately. (The LoadStatus object needs to indirectly
              # reference the counter through the ObjectGraphView).
              base.TrackableReference("save_counter", self._save_counter))
          # When loading a checkpoint, the save counter is created after
          # the checkpoint has been loaded, so it must be handled in a deferred
          # manner.
          if isinstance(self.root, weakref.ref):
            root = self.root()
          else:
            root = self.root
          restore = root._deferred_dependencies.pop("save_counter", ())  # pylint: disable=protected-access
          if restore:
            restore[0].restore(self._save_counter)

  def write(self, file_prefix, options=None):
    """Writes a training checkpoint.

    The checkpoint includes variables created by this object and any
    trackable objects it depends on at the time `Checkpoint.write()` is
    called.

    `write` does not number checkpoints, increment `save_counter`, or update the
    metadata used by `tf.train.latest_checkpoint`. It is primarily intended for
    use by higher level checkpoint management utilities. `save` provides a very
    basic implementation of these features.

    Checkpoints written with `write` must be read with `read`.

    Example usage:

    ```
    step = tf.Variable(0, name="step")
    checkpoint = tf.train.Checkpoint(step=step)
    checkpoint.write("/tmp/ckpt")

    # Later, read the checkpoint with read()
    checkpoint.read("/tmp/ckpt")

    # You can also pass options to write() and read(). For example this
    # runs the IO ops on the localhost:
    options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
    checkpoint.write("/tmp/ckpt", options=options)

    # Later, read the checkpoint with read()
    checkpoint.read("/tmp/ckpt", options=options)
    ```

    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix).
      options: Optional `tf.train.CheckpointOptions` object.

    Returns:
      The full path to the checkpoint (i.e. `file_prefix`).
    """
    if isinstance(file_prefix, os.PathLike):
      file_prefix = os.fspath(file_prefix)
    return self._write(file_prefix, options)

  def _write(self, file_prefix, options=None, write_done_callback=None):
    """Internal method that implements Checkpoint.write().

    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix).
      options: Optional `tf.train.CheckpointOptions` object.
      write_done_callback: Optional callback function to be executed once
        the underlying checkpoint saving is finished. Example usage includes
        updating the checkpoint internal state.

    Returns:
      The full path to the checkpoint (i.e. `file_prefix`).
    """
    start_time = time.time()
    options = options or checkpoint_options.CheckpointOptions()
    output = self._saver.save(
        file_prefix=file_prefix,
        options=options,
        write_done_callback=write_done_callback)
    output = _convert_file_name_tensor_to_string(output)
    # For async-checkpoint-v1 scenario, the callback function is triggered in
    # TrackableSaver. For others, it is executed here.
    if not options.experimental_enable_async_checkpoint and write_done_callback:
      write_done_callback(output)
    end_time = time.time()

    # This records the time checkpoint._write() blocks on the main thread.
    metrics.AddCheckpointWriteDuration(
        api_label=_CHECKPOINT_V2,
        microseconds=_get_duration_microseconds(start_time, end_time))

    global _END_TIME_OF_LAST_WRITE
    with _END_TIME_OF_LAST_WRITE_LOCK:
      metrics.AddTrainingTimeSaved(
          api_label=_CHECKPOINT_V2,
          microseconds=_get_duration_microseconds(_END_TIME_OF_LAST_WRITE,
                                                  end_time))
      _END_TIME_OF_LAST_WRITE = end_time

    # Async-checkpoint-v1 may not have finished yet, so we can't measure its
    # checkpoint size now. For other scenarios (sync-checkpoint and async-v2),
    # this metric recording works fine here.
    if not options.experimental_enable_async_checkpoint:
      metrics.RecordCheckpointSize(
          api_label=_CHECKPOINT_V2, filesize=_get_checkpoint_size(output))
    return output

  @property
  def save_counter(self):
    """An integer variable which starts at zero and is incremented on save.

    Used to number checkpoints.

    Returns:
      The save counter variable.
    """
    self._maybe_create_save_counter()
    return self._save_counter

  def save(self, file_prefix, options=None):
    # pylint:disable=line-too-long
    """Saves a training checkpoint and provides basic checkpoint management.

    The saved checkpoint includes variables created by this object and any
    trackable objects it depends on at the time `Checkpoint.save()` is
    called.

    `save` is a basic convenience wrapper around the `write` method,
    sequentially numbering checkpoints using `save_counter` and updating the
    metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
    management, for example garbage collection and custom numbering, may be
    provided by other utilities which also wrap `write` and `read`
    (`tf.train.CheckpointManager`, for example).

    ```
    step = tf.Variable(0, name="step")
    checkpoint = tf.train.Checkpoint(step=step)
    checkpoint.save("/tmp/ckpt")

    # Later, read the checkpoint with restore()
    checkpoint.restore("/tmp/ckpt-1")

    # You can also pass options to save() and restore(). For example this
    # runs the IO ops on the localhost:
    options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
    checkpoint.save("/tmp/ckpt", options=options)

    # Later, read the checkpoint with restore()
    checkpoint.restore("/tmp/ckpt-1", options=options)
    ```

    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix). Names are generated based on this
        prefix and `Checkpoint.save_counter`.
      options: Optional `tf.train.CheckpointOptions` object.

    Returns:
      The full path to the checkpoint.
    """
    if isinstance(file_prefix, os.PathLike):
      file_prefix = os.fspath(file_prefix)
    # pylint:enable=line-too-long
    options = options or checkpoint_options.CheckpointOptions()
    graph_building = not context.executing_eagerly()
    if graph_building:
      # Assert that async checkpoint is not used for non-eager mode.
      if options.experimental_enable_async_checkpoint:
        raise NotImplementedError(
            "Async checkpoint is not supported for non-eager mode.")

      if ops.inside_function():
        raise NotImplementedError(
            "Calling tf.train.Checkpoint.save() from a function is not "
            "supported, as save() modifies saving metadata in ways not "
            "supported by TensorFlow Operations. Consider using "
            "tf.train.Checkpoint.write(), a lower-level API which does not "
            "update metadata. tf.train.latest_checkpoint and related APIs will "
            "not see this checkpoint.")
      session = get_session()
      if self._save_counter is None:
        # When graph building, if this is a new save counter variable then it
        # needs to be initialized before assign_add. This is only an issue if
        # restore() has not been called first.
        session.run(self.save_counter.initializer)

    if not graph_building or self._save_assign_op is None:
      with ops.colocate_with(self.save_counter):
        assign_op = self.save_counter.assign_add(1, read_value=True)
      if graph_building:
        self._save_assign_op = data_structures.NoDependency(assign_op)

    if graph_building:
      checkpoint_number = session.run(self._save_assign_op)
    else:
      checkpoint_number = assign_op.numpy()

    file_path = self._write(
        "%s-%d" % (file_prefix, checkpoint_number),
        options=options,
        write_done_callback=_update_checkpoint_state_internal)

    # Ensure save operations have completed, only when running in eager runtime
    # and non-async checkpoint configuration.
    if not graph_building and not options.experimental_enable_async_checkpoint:
      context.async_wait()

    return file_path

  def read(self, save_path, options=None):
    """Reads a training checkpoint written with `write`.

    Reads this `Checkpoint` and any objects it depends on.

    This method is just like `restore()` but does not expect the `save_counter`
    variable in the checkpoint. It only restores the objects that the checkpoint
    already depends on.

    The method is primarily intended for use by higher level checkpoint
    management utilities that use `write()` instead of `save()` and have their
    own mechanisms to number and track checkpoints.

    Example usage:

    ```python
    # Create a checkpoint with write()
    ckpt = tf.train.Checkpoint(v=tf.Variable(1.))
    path = ckpt.write('/tmp/my_checkpoint')

    # Later, load the checkpoint with read()
    # With restore(), assert_consumed() would have failed.
    ckpt.read(path).assert_consumed()

    # You can also pass options to read(). For example this
    # runs the IO ops on the localhost:
    options = tf.train.CheckpointOptions(
        experimental_io_device="/job:localhost")
    ckpt.read(path, options=options)
    ```

    Args:
      save_path: The path to the checkpoint as returned by `write`.
      options: Optional `tf.train.CheckpointOptions` object.

    Returns:
      A load status object, which can be used to make assertions about the
      status of a checkpoint restoration. See `restore` for details.
    """
    start_time = time.time()
    if isinstance(save_path, os.PathLike):
      save_path = os.fspath(save_path)
    options = options or checkpoint_options.CheckpointOptions()
    result = self._saver.restore(save_path=save_path, options=options)
    metrics.AddCheckpointReadDuration(
        api_label=_CHECKPOINT_V2,
        microseconds=_get_duration_microseconds(start_time, time.time()))
    return result

  def restore(self, save_path, options=None):
    """Restores a training checkpoint.

    Restores this `Checkpoint` and any objects it depends on.

    This method is intended to be used to load checkpoints created by `save()`.
    For checkpoints created by `write()` use the `read()` method which does not
    expect the `save_counter` variable added by `save()`.

    `restore()` either assigns values immediately if variables to restore have
    been created already, or defers restoration until the variables are
    created. Dependencies added after this call will be matched if they have a
    corresponding object in the checkpoint (the restore request will queue in
    any trackable object waiting for the expected dependency to be added).

    ```python
    checkpoint = tf.train.Checkpoint( ... )
    checkpoint.restore(path)

    # You can additionally pass options to restore():
    options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
2486    checkpoint.restore(path, options=options)
2487    ```

    To ensure that loading is complete and no more deferred restorations will
    take place, use the `assert_consumed()` method of the status object returned
    by `restore()`:

    ```python
    checkpoint.restore(path, options=options).assert_consumed()
    ```

    The assert will raise an error if any Python objects in the dependency graph
    were not found in the checkpoint, or if any checkpointed values do not have
    a matching Python object.

    Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can be
    loaded using this method. Names are used to match variables. Re-encode
    name-based checkpoints using `tf.train.Checkpoint.save` as soon as possible.
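
    For example (a sketch; `model` and both paths are placeholders):

    ```python
    # Match variables in a TF1 name-based checkpoint by name, then
    # re-save them in the object-based format.
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint.restore('/path/to/tf1_checkpoint')
    checkpoint.save('/path/to/object_based_checkpoint')
    ```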

    **Loading from SavedModel checkpoints**

    To load values from a SavedModel, just pass the SavedModel directory
    to `checkpoint.restore`:

    ```python
    model = tf.keras.Model(...)
    tf.saved_model.save(model, path)  # or model.save(path, save_format='tf')

    checkpoint = tf.train.Checkpoint(model)
    checkpoint.restore(path).expect_partial()
    ```

    This example calls `expect_partial()` on the loaded status, since
    SavedModels saved from Keras often generate extra keys in the checkpoint.
    Otherwise, the program prints a lot of warnings about unused keys at exit
    time.

    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If the checkpoint was written by the
        name-based `tf.compat.v1.train.Saver`, names are used to match
        variables. This path may also be a SavedModel directory.
      options: Optional `tf.train.CheckpointOptions` object.

    Returns:
      A load status object, which can be used to make assertions about the
      status of a checkpoint restoration.

      The returned status object has the following methods:

      * `assert_consumed()`:
          Raises an exception if any variables are unmatched: either
          checkpointed values which don't have a matching Python object or
          Python objects in the dependency graph with no values in the
          checkpoint. This method returns the status object, and so may be
          chained with other assertions.

      * `assert_existing_objects_matched()`:
          Raises an exception if any existing Python objects in the dependency
          graph are unmatched. Unlike `assert_consumed`, this assertion will
          pass if values in the checkpoint have no corresponding Python
          objects. For example a `tf.keras.Layer` object which has not yet been
          built, and so has not created any variables, will pass this assertion
          but fail `assert_consumed`. Useful when loading part of a larger
          checkpoint into a new Python program, e.g. when a training checkpoint
          with a `tf.compat.v1.train.Optimizer` was saved but only the state
          required for inference is being loaded. This method returns the
          status object, and so may be chained with other assertions.

      * `assert_nontrivial_match()`: Asserts that something aside from the root
          object was matched. This is a very weak assertion, but is useful for
          sanity checking in library code where objects may exist in the
          checkpoint which haven't been created in Python and some Python
          objects may not have a checkpointed value.

      * `expect_partial()`: Silence warnings about incomplete checkpoint
          restores. Warnings are otherwise printed for unused parts of the
          checkpoint file or object when the `Checkpoint` object is deleted
          (often at program shutdown).
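
      A sketch of chaining these assertions (assuming a checkpoint at `path`
      with extra keys but a value for every object already created in Python):

      ```python
      status = checkpoint.restore(path)
      # Tolerate unused keys in the checkpoint, but require that every
      # existing Python object found a value:
      status.expect_partial().assert_existing_objects_matched()
      ```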

    Raises:
      NotFoundError: if a checkpoint or SavedModel cannot be found at
        `save_path`.
    """
    orig_save_path = save_path
    if isinstance(save_path, os.PathLike):
      save_path = os.fspath(save_path)

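    # If `save_path` is a directory that contains a saved_model.pb or
    # saved_model.pbtxt, restore from the checkpoint inside its
    # "variables/" subdirectory instead.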
    if save_path is not None and gfile.IsDirectory(save_path) and (
        (gfile.Exists(utils_impl.get_saved_model_pb_path(save_path)) or
         gfile.Exists(utils_impl.get_saved_model_pbtxt_path(save_path)))):
      save_path = utils_impl.get_variables_path(save_path)

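    # Delegate the actual restore to read(), rewrapping NotFoundError with a
    # message that points back at the original (possibly SavedModel) path.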
    try:
      status = self.read(save_path, options=options)
      if context.executing_eagerly():
        context.async_wait()  # Ensure restore operations have completed.
    except errors_impl.NotFoundError as e:
      raise errors_impl.NotFoundError(
          None, None,
          f"Error when restoring from checkpoint or SavedModel at "
          f"{orig_save_path}: {e.message}"
          f"\nPlease double-check that the path is correct. You may be missing "
          "the checkpoint suffix (e.g. the '-1' in 'path/to/ckpt-1').")
    # Create the save counter now so it gets initialized with other variables
    # when graph building. Creating it earlier would lead to errors when using,
    # say, train.Saver() to save the model before initializing it.
    self._maybe_create_save_counter()
    if isinstance(status, NameBasedSaverStatus):
      status.add_to_optionally_restored(self.save_counter)
    return status
