• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""An object-local variable management scheme."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import collections
22
23import six
24
25from tensorflow.python.eager import context
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import gen_io_ops as io_ops
32from tensorflow.python.platform import tf_logging as logging
33from tensorflow.python.training.saving import saveable_object
34from tensorflow.python.util import tf_contextlib
35from tensorflow.python.util import tf_decorator
36from tensorflow.python.util.tf_export import tf_export
37
38# Key where the object graph proto is saved in a TensorBundle
39OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"
40
41# A key indicating a variable's value in an object's checkpointed Tensors
42# (Trackable._gather_saveables_for_checkpoint). If this is the only key and
43# the object has no dependencies, then its value may be restored on object
44# creation (avoiding double assignment when executing eagerly).
45VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
46OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"
47
48
@tf_export("__internal__.tracking.TrackableReference", v1=[])
class TrackableReference(
    collections.namedtuple("TrackableReference", ["name", "ref"])):
  """Pairs a local name with the `Trackable` object it refers to.

  Instances mark named `Trackable` dependencies of another `Trackable`
  object; create them when overriding `Trackable._checkpoint_dependencies`.

  Attributes:
    name: The local name for this dependency.
    ref: The `Trackable` object being referenced.
  """
61
62
# TODO(bfontain):  Update once sharded initialization interface is finalized.
# Describes one shard of a partitioned value: `shape` is the shard's shape and
# `offset` is its starting position within the full variable (see
# CheckpointInitialValue.__init__, which formats these into a RestoreV2
# shape-and-slice string).
ShardInfo = collections.namedtuple(
    "CheckpointInitialValueShardInfo", ["shape", "offset"])
66
67
@tf_export("__internal__.tracking.CheckpointInitialValueCallable", v1=[])
class CheckpointInitialValueCallable(object):
  """Callable that produces a `CheckpointInitialValue` when invoked.

  See `CheckpointInitialValue` for more information.
  """

  def __init__(self, checkpoint_position):
    """Holds the checkpoint position the produced initial values come from."""
    self._checkpoint_position = checkpoint_position

  @property
  def checkpoint_position(self):
    """The checkpoint position backing this initializer."""
    return self._checkpoint_position

  @property
  def restore_uid(self):
    """UID of the restore this initializer's checkpoint position came from."""
    return self._checkpoint_position.restore_uid

  def __call__(self, shape=None, dtype=None, shard_info=None):
    """Builds a `CheckpointInitialValue` for the requested (sharded) shape.

    The signature mirrors ordinary callable initializers, which receive
    `shape` and `dtype`. Although `dtype` is unused here, wrappers such as
    the functools.partial produced by base_layer_utils.py's make_variable
    pass it in unconditionally, so it must be accepted.
    """
    return CheckpointInitialValue(
        self._checkpoint_position, shape, shard_info=shard_info)
93
94
@tf_export("__internal__.tracking.CheckpointInitialValue", v1=[])
class CheckpointInitialValue(ops.Tensor):
  """Tensor wrapper that carries restore-UID information for `Variables`.

  Supplied as an initial value, an instance tells the `Variable` being built
  (`Variable`, `ResourceVariable`, etc.) which restore its value came from.
  Deferred restorations can then be sequenced in the order the user requested
  them, with plain assignment as a fallback when an initial value goes unused
  (e.g. because a custom getter interfered).

  See the comments in _add_variable_with_custom_getter for more information
  about how `CheckpointInitialValue` is used.
  """

  def __init__(self, checkpoint_position, shape=None, shard_info=None):
    if not shard_info:
      shape_and_slice = ""
    else:
      # Build the shape-and-slice string consumed by RestoreV2:
      # space-separated full shape, a space, then "offset,size" pairs joined
      # by ":" (one pair per dimension).
      full_shape = " ".join("%d" % d for d in shape)
      slice_spec = ":".join(
          "%d,%d" % (o, s)
          for o, s in zip(shard_info.offset, shard_info.shape))
      shape_and_slice = full_shape + " " + slice_spec
    self.wrapped_value = checkpoint_position.value_tensors(
        {VARIABLE_VALUE_KEY: shape_and_slice})[VARIABLE_VALUE_KEY]
    self._checkpoint_position = checkpoint_position

  def __getattr__(self, attr):
    # Delegate unknown attribute lookups to the wrapped tensor; if it lacks
    # the attribute too, fall back to normal lookup, which re-raises
    # AttributeError for this object.
    try:
      return getattr(self.wrapped_value, attr)
    except AttributeError:
      return self.__getattribute__(attr)

  @property
  def checkpoint_position(self):
    """The checkpoint position this initial value was read from."""
    return self._checkpoint_position
130
131
class NoRestoreSaveable(saveable_object.SaveableObject):
  """Saves a tensor into a checkpoint; restoring it is a no-op."""

  def __init__(self, tensor, name, dtype=None, device=None):
    single_spec = saveable_object.SaveSpec(
        tensor, "", name, dtype=dtype, device=device)
    super(NoRestoreSaveable, self).__init__(tensor, [single_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    # Nothing to restore; return a no-op so callers still get a valid op.
    return control_flow_ops.no_op()
142
143
@six.add_metaclass(abc.ABCMeta)
class PythonStateSaveable(saveable_object.SaveableObject):
  """Interface for saving and restoring volatile Python state."""

  @abc.abstractmethod
  def feed_dict_additions(self):
    """Indicates fresh state to feed when running a graph.

    Returns:
      A dictionary mapping `Tensor`s to current Python state.
    """

  @abc.abstractmethod
  def freeze(self):
    """Creates a `SaveableObject` freezing the current state as a constant.

    Used to embed the current state as a constant when executing eagerly, or
    when building a static tf.compat.v1.train.Saver with the frozen current
    Python state.

    Returns:
      A `SaveableObject` that is not a `PythonStateSaveable` instance (i.e.
      carries no Python state).
    """
170
171
class PythonStringStateSaveable(PythonStateSaveable):
  """Saves Python string state in a checkpoint."""

  def __init__(self, name, state_callback, restore_callback=None):
    """Configure saving.

    Args:
      name: The checkpoint key to write to.
      state_callback: A function taking no arguments which returns a string.
        This function is run every time a checkpoint is written.
      restore_callback: A function taking a Python string, used to restore
        state. Optional; defaults to doing nothing, in which case it is
        ignored by status assertions such as assert_consumed().
    """
    self._has_trivial_state_callback = (restore_callback is None)

    def _collect_state_in_init_scope():
      # Run the callback outside any function-building context so it sees
      # the live Python state.
      with ops.init_scope():
        return state_callback()

    self._state_callback = _collect_state_in_init_scope
    self._restore_callback = restore_callback
    # The placeholder string lives on the CPU; its value is fed (or frozen)
    # at save time.
    with ops.device("/cpu:0"):
      self._save_string = constant_op.constant("", dtype=dtypes.string)
    spec = saveable_object.SaveSpec(
        self._save_string, "", name, dtype=dtypes.string)
    super(PythonStringStateSaveable, self).__init__(
        self._save_string, [spec], name)

  @property
  def optional_restore(self):
    """Relaxes assert_consumed() for values that have nothing to restore."""
    return self._has_trivial_state_callback

  def feed_dict_additions(self):
    """Indicates fresh state to feed when running a graph."""
    return {self._save_string: self._state_callback()}

  def freeze(self):
    """Creates a frozen `SaveableObject` capturing the current state."""

    def _current_state_constant():
      return constant_op.constant(self._state_callback(), dtype=dtypes.string)

    return NoRestoreSaveable(
        tensor=_current_state_constant,
        dtype=dtypes.string,
        name=self.name,
        device="cpu:0")

  def python_restore(self, restored_strings):
    """Restores Python state from the checkpointed string, if configured."""
    if self._restore_callback:
      restored, = restored_strings
      self._restore_callback(restored)

  def restore(self, restored_tensors, restored_shapes):
    """Restores TensorFlow state (nothing to do here)."""
    return control_flow_ops.no_op()
231
232
class CheckpointPosition(object):
  """Indicates a position within a `_CheckpointRestoreCoordinator`.

  Pairs a restore coordinator with the index of one node in the checkpoint's
  object graph proto, and provides the operations (binding, value reading,
  restore-op creation) needed to restore that node into a Python object.
  """

  __slots__ = ["_checkpoint", "_proto_id"]

  def __init__(self, checkpoint, proto_id):
    """Specify an object within a checkpoint.

    Args:
      checkpoint: A _CheckpointRestoreCoordinator object.
      proto_id: The index of this object in TrackableObjectGraph.nodes.
    """
    self._checkpoint = checkpoint
    self._proto_id = proto_id

  def restore(self, trackable):
    """Restore this value into `trackable`."""
    # Restoration creates ops/variables, so lift out of any function-building
    # graph context.
    with ops.init_scope():
      if self.bind_object(trackable):
        # This object's correspondence with a checkpointed object is new, so
        # process deferred restorations for it and its dependencies.
        restore_ops = trackable._restore_from_checkpoint_position(self)  # pylint: disable=protected-access
        if restore_ops:
          self._checkpoint.new_restore_ops(restore_ops)

  def bind_object(self, trackable):
    """Set a checkpoint<->object correspondence and process slot variables.

    Args:
      trackable: The object to record a correspondence for.

    Returns:
      True if this is a new assignment, False if this object has already been
      mapped to a checkpointed `Object` proto.

    Note: if the proto was already bound to a *different* object, a warning is
    logged and False is returned; no exception is raised.
    """
    checkpoint = self.checkpoint
    checkpoint.all_python_objects.add(trackable)
    current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
    checkpoint.matched_proto_ids.add(self._proto_id)
    if current_assignment is None:
      checkpoint.object_by_proto_id[self._proto_id] = trackable
      # If slot-variable restorations were queued for this proto id before the
      # object existed (see the deferral below), replay them now.
      for deferred_slot_restoration in (
          checkpoint.deferred_slot_restorations.pop(self._proto_id, ())):
        trackable._create_or_restore_slot_variable(  # pylint: disable=protected-access
            slot_variable_position=CheckpointPosition(
                checkpoint=checkpoint,
                proto_id=deferred_slot_restoration.slot_variable_id),
            variable=deferred_slot_restoration.original_variable,
            slot_name=deferred_slot_restoration.slot_name)
      # Restore any slot variables the checkpoint associates with this object
      # (treating `trackable` as the slotted variable).
      for slot_restoration in checkpoint.slot_restorations.pop(
          self._proto_id, ()):
        optimizer_object = checkpoint.object_by_proto_id.get(
            slot_restoration.optimizer_id, None)
        if optimizer_object is None:
          # The optimizer has not yet been created or tracked. Record in the
          # checkpoint that the slot variables need to be restored when it is.
          checkpoint.deferred_slot_restorations.setdefault(
              slot_restoration.optimizer_id, []).append(
                  _DeferredSlotVariableRestoration(
                      original_variable=trackable,
                      slot_variable_id=slot_restoration.slot_variable_id,
                      slot_name=slot_restoration.slot_name))

        # `optimizer_object` can be a `Checkpoint` when user only needs the
        # attributes the optimizer holds, such as `iterations`. In those cases,
        # it would not have the optimizer's `_create_or_restore_slot_variable`
        # method.
        elif hasattr(optimizer_object, "_create_or_restore_slot_variable"):
          optimizer_object._create_or_restore_slot_variable(  # pylint: disable=protected-access
              slot_variable_position=CheckpointPosition(
                  checkpoint=checkpoint,
                  proto_id=slot_restoration.slot_variable_id),
              variable=trackable,
              slot_name=slot_restoration.slot_name)
      return True  # New assignment
    else:
      # The object was already mapped for this checkpoint load, which means
      # we don't need to do anything besides check that the mapping is
      # consistent (if the dependency DAG is not a tree then there are
      # multiple paths to the same object).
      if current_assignment is not trackable:
        logging.warning((
            "Inconsistent references when loading the checkpoint into this "
            "object graph. Either the Trackable object references in the "
            "Python program have changed in an incompatible way, or the "
            "checkpoint was generated in an incompatible program.\n\nTwo "
            "checkpoint references resolved to different objects (%s and %s)."),
                        current_assignment, trackable)
      return False  # Not a new assignment

  def is_simple_variable(self):
    """Determine whether this value is restorable with a Tensor initializer."""
    # "Simple" means: exactly one attribute, it is the variable's value, and
    # there are no child dependencies.
    attributes = self.object_proto.attributes
    return (len(attributes) == 1 and
            attributes[0].name == VARIABLE_VALUE_KEY and
            not self.object_proto.children)

  def value_tensors(self, shape_and_slices=None):
    """Create value `Tensor`s for this object's attributes.

    Does not require that the Python object has been created. Used for
    restore-on-create when executing eagerly.

    Args:
      shape_and_slices: A dict mapping from object attribute names to a shape
        and slice string that will be passed to a RestoreV2 op. If the dict is
        None or if an object attribute is not in the dict, the full tensor will
        be restored.

    Returns:
      A dictionary mapping from object attribute names to `Tensor`s.
    """
    value_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
      with ops.init_scope():
        with ops.device(io_device):
          # Run the restore itself on the io_device(CPU or specified).
          if (shape_and_slices is not None and
              serialized_tensor.name in shape_and_slices):
            shape_and_slice = shape_and_slices[serialized_tensor.name]
          else:
            shape_and_slice = ""
          value, = io_ops.restore_v2(
              prefix=self._checkpoint.save_path_tensor,
              tensor_names=[checkpoint_key],
              shape_and_slices=[shape_and_slice],
              dtypes=[base_type],
              name="%s_checkpoint_read" % (serialized_tensor.name,))
        # Copy the value to the current device if necessary.
        value_tensors[serialized_tensor.name] = array_ops.identity(value)
    return value_tensors

  def gather_ops_or_named_saveables(self):
    """Looks up or creates SaveableObjects which don't have cached ops.

    Returns:
      A tuple (existing_restore_ops, named_saveables, python_saveables):
      previously-created restore ops, a dict mapping checkpoint keys to
      SaveableObjects, and a list of PythonStateSaveables.
    """
    saveables = self.trackable._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
    # Name saveables based on the name this object had when it was checkpointed.
    named_saveables = {}
    python_saveables = []
    existing_restore_ops = []
    for serialized_tensor in self.object_proto.attributes:
      if context.executing_eagerly():
        existing_op = None
      else:
        existing_op = self._checkpoint.restore_ops_by_name.get(
            serialized_tensor.checkpoint_key, None)
      if existing_op is not None:
        existing_restore_ops.append(existing_op)
        continue

      # Only if we don't have cached ops for this SaveableObject, we'll see if
      # the SaveableObject itself has been cached. If not, we'll make it, and
      # either way we'll extract new ops from it (or if it has Python state to
      # restore, we'll run that).
      saveables_cache = self._checkpoint.graph_view.saveables_cache
      if saveables_cache is None:
        # No SaveableObject caching when executing eagerly.
        saveable = None
      else:
        # If we've already created and cached a SaveableObject for this
        # attribute, we can re-use it to avoid re-creating some ops when graph
        # building.
        saveable_list = saveables_cache.get(self.trackable,
                                            {}).get(serialized_tensor.name,
                                                    (None,))
        if len(saveable_list) == 1:
          # Almost every attribute will have exactly one SaveableObject.
          saveable, = saveable_list
        else:
          # Don't use cached SaveableObjects for partitioned variables, which is
          # the only case where we'd have a list of SaveableObjects. Op caching
          # will catch them.
          saveable = None
      if saveable is not None:
        # The name of this attribute has changed, so we need to re-generate
        # the SaveableObject.
        if serialized_tensor.checkpoint_key not in saveable.name:
          saveable = None
          del saveables_cache[self.trackable]
      if saveable is None:
        # If there was no cached SaveableObject, we should check if the Python
        # object has the attribute.
        saveable_factory = saveables.get(serialized_tensor.name, None)
        if saveable_factory is None:
          # Purposefully does not throw an exception if attributes have been
          # added or deleted. Stores unused attributes so an exception can be
          # raised if the user decides to check that everything in the
          # checkpoint was loaded.
          if not serialized_tensor.optional_restore:
            self._checkpoint.unused_attributes.setdefault(
                self._proto_id, []).append(serialized_tensor.name)
          continue
        if callable(saveable_factory):
          saveable = saveable_factory(name=serialized_tensor.checkpoint_key)
        else:
          saveable = saveable_factory
        if saveables_cache is not None:
          saveables_cache.setdefault(self.trackable,
                                     {})[serialized_tensor.name] = [saveable]
      if isinstance(saveable, PythonStateSaveable):
        python_saveables.append(saveable)
      else:
        named_saveables[serialized_tensor.checkpoint_key] = saveable
    return existing_restore_ops, named_saveables, python_saveables

  def restore_ops(self):
    """Create or fetch restore ops for this object's attributes.

    Requires that the `Trackable` Python object has been bound to an object
    ID in the checkpoint.

    Returns:
      A list of operations when graph building, or an empty list when executing
      eagerly.
    """
    (restore_ops, tensor_saveables,
     python_saveables) = self.gather_ops_or_named_saveables()
    restore_ops.extend(
        self._checkpoint.restore_saveables(tensor_saveables, python_saveables))
    return restore_ops

  @property
  def checkpoint(self):
    """The `_CheckpointRestoreCoordinator` this position belongs to."""
    return self._checkpoint

  @property
  def trackable(self):
    """The Python object bound to this position (set by `bind_object`)."""
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    """The node proto for this object in the checkpoint's object graph."""
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def restore_uid(self):
    """UID of the restore this position participates in."""
    return self._checkpoint.restore_uid

  def __repr__(self):
    return repr(self.object_proto)

  def value_shape(self):
    """The shape of the VARIABLE_VALUE tensor.

    Returns:
      If found a TensorShape object, otherwise None.
    """
    for serialized_tensor in self.object_proto.attributes:
      if serialized_tensor.name == VARIABLE_VALUE_KEY:
        return self._checkpoint.shape_map[serialized_tensor.checkpoint_key]
    return None
488
489
# Records a slot-variable restoration that must wait until its optimizer
# object has been created and tracked (see CheckpointPosition.bind_object).
_DeferredSlotVariableRestoration = collections.namedtuple(
    "_DeferredSlotVariableRestoration", [
        "original_variable",
        "slot_variable_id",
        "slot_name",
    ])

# A pending slot-variable restoration keyed off the slotted variable's proto
# id, pointing at the optimizer and slot variable in the checkpoint.
_SlotVariableRestoration = collections.namedtuple(
    "_SlotVariableRestoration",
    [
        # The checkpoint proto id of the optimizer object.
        "optimizer_id",
        # The checkpoint proto id of the slot variable.
        "slot_variable_id",
        "slot_name",
    ])
506
507
@tf_export("__internal__.tracking.no_automatic_dependency_tracking", v1=[])
def no_automatic_dependency_tracking(method):
  """Disables automatic dependency tracking on attribute assignment.

  Use to decorate any method of a Trackable object. Attribute assignment in
  that method will not add dependencies (also respected in Model). Harmless if
  used in a class which does not do automatic dependency tracking (which means
  it's safe to use in base classes which may have subclasses which also inherit
  from Trackable).

  Args:
    method: The method to decorate.

  Returns:
    A decorated method which sets and un-sets automatic dependency tracking for
    the object the method is called on (not thread safe).
  """

  def _method_wrapper(self, *args, **kwargs):
    # Remember the current setting (defaulting to enabled), switch tracking
    # off for the duration of the call, and restore it afterwards even if the
    # wrapped method raises.
    saved_tracking = getattr(self, "_self_setattr_tracking", True)
    self._self_setattr_tracking = False  # pylint: disable=protected-access
    try:
      return method(self, *args, **kwargs)
    finally:
      self._self_setattr_tracking = saved_tracking  # pylint: disable=protected-access

  return tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)
537
538
@tf_contextlib.contextmanager
def no_manual_dependency_tracking_scope(obj):
  """A context that disables manual dependency tracking for the given `obj`.

  Some library methods track objects on their own; wrap such calls in this
  context manager when you would rather disable that tracking and record the
  dependencies yourself.

  For example:

  class TestLayer(tf.keras.Layer):
    def build():
      with no_manual_dependency_tracking_scope(self):
        var = self.add_variable("name1")  # Creates a var and doesn't track it
      self._track_trackable("name2", var)  # We track variable with name `name2`

  Args:
    obj: A trackable object.

  Yields:
    a scope in which the object doesn't track dependencies manually.
  """
  # pylint: disable=protected-access
  saved_tracking = getattr(obj, "_manual_tracking", True)
  obj._manual_tracking = False
  try:
    yield
  finally:
    obj._manual_tracking = saved_tracking
569
570
@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
  """A context that disables automatic dependency tracking when assigning attrs.

  Objects that inherit from Autotrackable automatically creates dependencies
  to trackable objects through attribute assignments, and wraps data structures
  (lists or dicts) with trackable classes. This scope may be used to temporarily
  disable this behavior. This works similar to the decorator
  `no_automatic_dependency_tracking`.

  Example usage:
  ```
  model = tf.keras.Model()
  model.arr1 = []  # Creates a ListWrapper object
  with no_automatic_dependency_tracking_scope(model):
    model.arr2 = []  # Creates a regular, untracked python list
  ```

  Args:
    obj: A trackable object.

  Yields:
    a scope in which the object doesn't track dependencies.
  """
  # pylint: disable=protected-access
  saved_tracking = getattr(obj, "_setattr_tracking", True)
  obj._setattr_tracking = False
  try:
    yield
  finally:
    obj._setattr_tracking = saved_tracking
601
602
603@tf_export("__internal__.tracking.Trackable", v1=[])
604class Trackable(object):
605  """Base class for `Trackable` objects without automatic dependencies.
606
607  This class has no __setattr__ override for performance reasons. Dependencies
608  must be added explicitly. Unless attribute assignment is performance-critical,
609  use `AutoTrackable` instead. Use `Trackable` for `isinstance`
610  checks.
611  """
612
613  # For compatibility with wrapt.ObjectProxy, attributes are all prefixed with
614  # _self_. We have some properties to forward semi-public attributes to their
615  # _self_ equivalents.
616
617  @property
618  def _setattr_tracking(self):
619    if not hasattr(self, "_self_setattr_tracking"):
620      self._self_setattr_tracking = True
621    return self._self_setattr_tracking
622
623  @_setattr_tracking.setter
624  def _setattr_tracking(self, value):
625    self._self_setattr_tracking = value
626
  @property
  def _update_uid(self):
    # UID of the most recent restore/assignment applied to this object; the
    # backing attribute is _self_-prefixed for wrapt.ObjectProxy
    # compatibility. Raises AttributeError until set (see
    # _maybe_initialize_trackable, which initializes it to -1).
    return self._self_update_uid

  @_update_uid.setter
  def _update_uid(self, value):
    self._self_update_uid = value
634
  @property
  def _unconditional_checkpoint_dependencies(self):
    # List of TrackableReference objects, populated by
    # _maybe_initialize_trackable; forwarded through a _self_-prefixed
    # attribute for wrapt.ObjectProxy compatibility.
    return self._self_unconditional_checkpoint_dependencies

  @property
  def _unconditional_dependency_names(self):
    # Dict mapping local dependency names to Trackable objects.
    return self._self_unconditional_dependency_names

  @property
  def _name_based_restores(self):
    # Set of name-based restore coordinators seen by this object (see
    # _name_based_attribute_restore).
    return self._self_name_based_restores
646
  # Trackable does not do automatic dependency tracking, but uses the
  # no_automatic_dependency_tracking decorator so it can avoid adding
  # dependencies if a subclass is Trackable / inherits from Model (both of
  # which have __setattr__ overrides).
  @no_automatic_dependency_tracking
  def _maybe_initialize_trackable(self):
    """Initialize dependency management.

    Not __init__, since most objects will forget to call it. Idempotent: a
    second call returns immediately without touching existing state.
    """
    if hasattr(self, "_self_unconditional_checkpoint_dependencies"):
      # __init__ already called. This check means that we don't need
      # Trackable.__init__() in the constructor of every TensorFlow object.
      return
    # A list of TrackableReference objects. Some classes implementing
    # `Trackable`, notably `Optimizer`s, may override the
    # _checkpoint_dependencies property with conditional dependencies
    # (e.g. based on the current graph when saving).
    self._self_unconditional_checkpoint_dependencies = []
    # Maps names -> Trackable objects
    self._self_unconditional_dependency_names = {}
    # Restorations for other Trackable objects on which this object may
    # eventually depend. Maps local name -> CheckpointPosition list. Optimizers
    # tack on conditional dependencies, and so need separate management of
    # deferred dependencies too.
    self._self_unconditional_deferred_dependencies = {}
    # The UID of the highest assignment to this object. Used to ensure that the
    # last requested assignment determines the final value of an object.
    if hasattr(self, "_self_update_uid"):
      raise AssertionError(
          "Internal error: the object had an update UID set before its "
          "initialization code was run.")
    self._self_update_uid = -1
    # When executing eagerly, holds a collection of _NameBasedRestoreCoordinator
    # instances, which should be checked when creating variables or other
    # saveables. These are passed on recursively to all dependencies, since
    # unlike object-based checkpoint restores we don't know which subgraph is
    # being restored in advance. This mechanism is only necessary for
    # restore-on-create when executing eagerly, and so is unused when graph
    # building.
    self._self_name_based_restores = set()

    # Dictionary of SaveableObjects factories. This dictionary is defined when
    # the object is loaded from the SavedModel. When writing a custom class,
    # prefer overriding "_gather_saveables_from_checkpoint" to using this
    # attribute.
    self._self_saveable_object_factories = {}
694
  @property
  def _object_identifier(self):
    """String used to identify this object in a SavedModel.

    Generally, the object identifier is constant across objects of the same
    class, while the metadata field is used for instance-specific data.

    Returns:
      String object identifier.
    """
    return "_generic_user_object"
706
  def _no_dependency(self, value):
    """If automatic dependency tracking is enabled, ignores `value`.

    Base `Trackable` does no automatic tracking, so this is an identity
    function here; it exists so callers can mark values as non-dependencies
    uniformly. Presumably overridden by auto-tracking subclasses — confirm
    against AutoTrackable.
    """
    return value
710
  def _name_based_attribute_restore(self, checkpoint):
    """Restore the object's attributes from a name-based checkpoint."""
    # Remember the coordinator so later-created saveables can consult it.
    self._self_name_based_restores.add(checkpoint)
    # Only apply this restore if it is newer than the last assignment the
    # object has seen; the highest restore UID wins.
    if self._self_update_uid < checkpoint.restore_uid:
      checkpoint.eager_restore(self)
      self._self_update_uid = checkpoint.restore_uid
717
  @property
  def _checkpoint_dependencies(self):
    """All dependencies of this object.

    May be overridden to include conditional dependencies.

    Returns:
      A list of `TrackableReference` objects indicating named
      `Trackable` dependencies which should be saved along with this
      object.
    """
    return self._self_unconditional_checkpoint_dependencies
730
  @property
  def _deferred_dependencies(self):
    """A dictionary with deferred dependencies.

    Stores restorations for other Trackable objects on which this object
    may eventually depend. May be overridden by sub-classes (e.g. Optimizers use
    conditional dependencies based the current graph, and so need separate
    management of deferred dependencies too).

    Returns:
      A dictionary mapping from local name to a list of CheckpointPosition
      objects.
    """
    return self._self_unconditional_deferred_dependencies
745
746  def _lookup_dependency(self, name):
747    """Look up a dependency by name.
748
749    May be overridden to include conditional dependencies.
750
751    Args:
752      name: The local name of the dependency.
753
754    Returns:
755      A `Trackable` object, or `None` if no dependency by this name was
756      found.
757    """
758    return self._self_unconditional_dependency_names.get(name, None)
759
  def _add_variable_with_custom_getter(self,
                                       name,
                                       shape=None,
                                       dtype=dtypes.float32,
                                       initializer=None,
                                       getter=None,
                                       overwrite=False,
                                       **kwargs_for_getter):
    """Restore-on-create for a variable to be saved with this `Trackable`.

    If the user has requested that this object or another `Trackable` which
    depends on this object be restored from a checkpoint (deferred loading
    before variable object creation), `initializer` may be ignored and the value
    from the checkpoint used instead.

    Args:
      name: A name for the variable. Must be unique within this object.
      shape: The shape of the variable.
      dtype: The data type of the variable.
      initializer: The initializer to use. Ignored if there is a deferred
        restoration left over from a call to
        `_restore_from_checkpoint_position`.
      getter: The getter to wrap which actually fetches the variable.
      overwrite: If True, disables unique name and type checks.
      **kwargs_for_getter: Passed to the getter.

    Returns:
      The new variable object.

    Raises:
      ValueError: If the variable name is not unique.
    """
    self._maybe_initialize_trackable()
    # init_scope lifts the restore-on-create decision out of any function or
    # control-flow graph currently being built, so it is made in the outermost
    # eager/graph context.
    with ops.init_scope():
      if context.executing_eagerly():
        # If this is a variable with a single Tensor stored in the checkpoint,
        # we can set that value as an initializer rather than initializing and
        # then assigning (when executing eagerly). This call returns None if
        # there is nothing to restore.
        checkpoint_initializer = self._preload_simple_restoration(
            name=name)
      else:
        # Graph mode: restoration happens via assignment ops after tracking.
        checkpoint_initializer = None
      if (checkpoint_initializer is not None and
          not (isinstance(initializer, CheckpointInitialValueCallable) and
               (initializer.restore_uid > checkpoint_initializer.restore_uid))):
        # If multiple Trackable objects are "creating" the same variable
        # via the magic of custom getters, the one with the highest restore UID
        # (the one called last) has to make the final initializer. If another
        # custom getter interrupts this process by overwriting the initializer,
        # then we'll catch that when we call _track_trackable. So this is
        # "best effort" to set the initializer with the highest restore UID.
        initializer = checkpoint_initializer
    # Delegate actual variable creation to the wrapped getter.
    new_variable = getter(
        name=name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        **kwargs_for_getter)

    # If we set an initializer and the variable processed it, tracking will not
    # assign again. It will add this variable to our dependencies, and if there
    # is a non-trivial restoration queued, it will handle that. This also
    # handles slot variables.
    if not overwrite or isinstance(new_variable, Trackable):
      return self._track_trackable(new_variable, name=name, overwrite=overwrite)
    else:
      # TODO(allenl): Some variable types are not yet supported. Remove this
      # fallback once all get_variable() return types are Trackable.
      return new_variable
830
831  def _preload_simple_restoration(self, name):
832    """Return a dependency's value for restore-on-create.
833
834    Note the restoration is not deleted; if for some reason preload is called
835    and then not assigned to the variable (for example because a custom getter
836    overrides the initializer), the assignment will still happen once the
837    variable is tracked (determined based on checkpoint.restore_uid).
838
839    Args:
840      name: The object-local name of the dependency holding the variable's
841        value.
842
843    Returns:
844      An callable for use as a variable's initializer/initial_value, or None if
845      one should not be set (either because there was no variable with this name
846      in the checkpoint or because it needs more complex deserialization). Any
847      non-trivial deserialization will happen when the variable object is
848      tracked.
849    """
850    deferred_dependencies_list = self._deferred_dependencies.get(name, ())
851    if not deferred_dependencies_list:
852      # Nothing to do; we don't have a restore for this dependency queued up.
853      return
854    for checkpoint_position in deferred_dependencies_list:
855      if not checkpoint_position.is_simple_variable():
856        # If _any_ pending restoration is too complicated to fit in an
857        # initializer (because it has dependencies, or because there are
858        # multiple Tensors to restore), bail and let the general tracking code
859        # handle it.
860        return None
861    checkpoint_position = max(
862        deferred_dependencies_list,
863        key=lambda restore: restore.checkpoint.restore_uid)
864    return CheckpointInitialValueCallable(
865        checkpoint_position=checkpoint_position)
866
867  def _track_trackable(self, trackable, name, overwrite=False):
868    """Declare a dependency on another `Trackable` object.
869
870    Indicates that checkpoints for this object should include variables from
871    `trackable`.
872
873    Variables in a checkpoint are mapped to `Trackable`s based on the names
874    provided when the checkpoint was written. To avoid breaking existing
875    checkpoints when modifying a class, neither variable names nor dependency
876    names (the names passed to `_track_trackable`) may change.
877
878    Args:
879      trackable: A `Trackable` which this object depends on.
880      name: A local name for `trackable`, used for loading checkpoints into the
881        correct objects.
882      overwrite: Boolean, whether silently replacing dependencies is OK. Used
883        for __setattr__, where throwing an error on attribute reassignment would
884        be inappropriate.
885
886    Returns:
887      `trackable`, for convenience when declaring a dependency and
888      assigning to a member variable in one statement.
889
890    Raises:
891      TypeError: If `trackable` does not inherit from `Trackable`.
892      ValueError: If another object is already tracked by this name.
893    """
894    self._maybe_initialize_trackable()
895    if not isinstance(trackable, Trackable):
896      raise TypeError(("Trackable._track_trackable() passed type %s, not a "
897                       "Trackable.") % (type(trackable),))
898    if not getattr(self, "_manual_tracking", True):
899      return trackable
900    new_reference = TrackableReference(name=name, ref=trackable)
901    current_object = self._lookup_dependency(name)
902    if (current_object is not None and current_object is not trackable):
903      if not overwrite:
904        raise ValueError(
905            ("Called Trackable._track_trackable() with name='%s', "
906             "but a Trackable with this name is already declared as a "
907             "dependency. Names must be unique (or overwrite=True).") % (name,))
908      # This is a weird thing to do, but we're not going to stop people from
909      # using __setattr__.
910      for index, (old_name, _) in enumerate(
911          self._self_unconditional_checkpoint_dependencies):
912        if name == old_name:
913          self._self_unconditional_checkpoint_dependencies[
914              index] = new_reference
915    elif current_object is None:
916      self._self_unconditional_checkpoint_dependencies.append(new_reference)
917      self._handle_deferred_dependencies(name=name, trackable=trackable)
918    self._self_unconditional_dependency_names[name] = trackable
919    return trackable
920
921  def _handle_deferred_dependencies(self, name, trackable):
922    """Pop and load any deferred checkpoint restores into `trackable`.
923
924    This method does not add a new dependency on `trackable`, but it does
925    check if any outstanding/deferred dependencies have been queued waiting for
926    this dependency to be added (matched based on `name`). If so,
927    `trackable` and its dependencies are restored. The restorations are
928    considered fulfilled and so are deleted.
929
930    `_track_trackable` is more appropriate for adding a
931    normal/unconditional dependency, and includes handling for deferred
932    restorations. This method allows objects such as `Optimizer` to use the same
933    restoration logic while managing conditional dependencies themselves, by
934    overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the
935    object's dependencies based on the context it is saved/restored in (a single
936    optimizer instance can have state associated with multiple graphs).
937
938    Args:
939      name: The name of the dependency within this object (`self`), used to
940        match `trackable` with values saved in a checkpoint.
941      trackable: The Trackable object to restore (inheriting from `Trackable`).
942    """
943    self._maybe_initialize_trackable()
944    trackable._maybe_initialize_trackable()  # pylint: disable=protected-access
945    deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
946    for checkpoint_position in sorted(
947        deferred_dependencies_list,
948        key=lambda restore: restore.checkpoint.restore_uid,
949        reverse=True):
950      checkpoint_position.restore(trackable)
951
952    # Pass on any name-based restores queued in this object.
953    for name_based_restore in sorted(
954        self._self_name_based_restores,
955        key=lambda checkpoint: checkpoint.restore_uid,
956        reverse=True):
957      trackable._name_based_attribute_restore(name_based_restore)  # pylint: disable=protected-access
958
959  def _restore_from_checkpoint_position(self, checkpoint_position):
960    """Restore this object and its dependencies (may be deferred)."""
961    # Attempt a breadth-first traversal, since presumably the user has more
962    # control over shorter paths. If we don't have all of the dependencies at
963    # this point, the end result is not breadth-first (since other deferred
964    # traversals will happen later).
965    visit_queue = collections.deque([checkpoint_position])
966    restore_ops = []
967    tensor_saveables = {}
968    python_saveables = []
969    while visit_queue:
970      current_position = visit_queue.popleft()
971      new_restore_ops, new_tensor_saveables, new_python_saveables = (
972          current_position.trackable  # pylint: disable=protected-access
973          ._single_restoration_from_checkpoint_position(
974              checkpoint_position=current_position,
975              visit_queue=visit_queue))
976      restore_ops.extend(new_restore_ops)
977      tensor_saveables.update(new_tensor_saveables)
978      python_saveables.extend(new_python_saveables)
979    restore_ops.extend(
980        current_position.checkpoint.restore_saveables(
981            tensor_saveables, python_saveables))
982    return restore_ops
983
984  def _single_restoration_from_checkpoint_position(self, checkpoint_position,
985                                                   visit_queue):
986    """Restore this object, and either queue its dependencies or defer them."""
987    self._maybe_initialize_trackable()
988    checkpoint = checkpoint_position.checkpoint
989    # If the UID of this restore is lower than our current update UID, we don't
990    # need to actually restore the object. However, we should pass the
991    # restoration on to our dependencies.
992    if checkpoint.restore_uid > self._self_update_uid:
993      restore_ops, tensor_saveables, python_saveables = (
994          checkpoint_position.gather_ops_or_named_saveables())
995      self._self_update_uid = checkpoint.restore_uid
996    else:
997      restore_ops = ()
998      tensor_saveables = {}
999      python_saveables = ()
1000    for child in checkpoint_position.object_proto.children:
1001      child_position = CheckpointPosition(
1002          checkpoint=checkpoint, proto_id=child.node_id)
1003      local_object = self._lookup_dependency(child.local_name)
1004      if local_object is None:
1005        # We don't yet have a dependency registered with this name. Save it
1006        # in case we do.
1007        self._deferred_dependencies.setdefault(child.local_name,
1008                                               []).append(child_position)
1009      else:
1010        if child_position.bind_object(trackable=local_object):
1011          # This object's correspondence is new, so dependencies need to be
1012          # visited. Delay doing it so that we get a breadth-first dependency
1013          # resolution order (shallowest paths first). The caller is responsible
1014          # for emptying visit_queue.
1015          visit_queue.append(child_position)
1016    return restore_ops, tensor_saveables, python_saveables
1017
1018  def _gather_saveables_for_checkpoint(self):
1019    """Returns a dictionary of values to checkpoint with this object.
1020
1021    Keys in the returned dictionary are local to this object and in a separate
1022    namespace from dependencies. Values may either be `SaveableObject` factories
1023    or variables easily converted to `SaveableObject`s (as in
1024    `tf.compat.v1.train.Saver`'s
1025    `var_list` constructor argument).
1026
1027    `SaveableObjects` have a name set, which Trackable needs to generate
1028    itself. So rather than returning `SaveableObjects` directly, this method
1029    should return a dictionary of callables which take `name` arguments and
1030    return `SaveableObjects` with that name.
1031
1032    If this object may also be passed to the global-name-based
1033    `tf.compat.v1.train.Saver`,
1034    the returned callables should have a default value for their name argument
1035    (i.e. be callable with no arguments).
1036
1037    Returned values must be saved only by this object; if any value may be
1038    shared, it should instead be a dependency. For example, variable objects
1039    save their own values with the key `VARIABLE_VALUE_KEY`, but objects which
1040    reference variables simply add a dependency.
1041
1042    Returns:
1043      The dictionary mapping attribute names to `SaveableObject` factories
1044      described above. For example:
1045      {VARIABLE_VALUE_KEY:
1046       lambda name="global_name_for_this_object":
1047       SaveableObject(name=name, ...)}
1048    """
1049    return self._self_saveable_object_factories
1050
1051  def _list_extra_dependencies_for_serialization(self, serialization_cache):
1052    """Lists extra dependencies to serialize.
1053
1054    Internal sub-classes can override this method to return extra dependencies
1055    that should be saved with the object during SavedModel serialization. For
1056    example, this is used to save `trainable_variables` in Keras models. The
1057    python property `trainable_variables` contains logic to iterate through the
1058    weights from the model and its sublayers. The serialized Keras model saves
1059    `trainable_weights` as a trackable list of variables.
1060
1061    PLEASE NOTE when overriding this method:
1062      1. This function may only generate new trackable objects the first time it
1063         is called.
1064      2. The returned dictionary must not have naming conflicts with
1065         dependencies tracked by the root. In other words, if the root is
1066         tracking `object_1` with name 'x', and this functions returns
1067         `{'x': object_2}`, an error is raised when saving.
1068
1069    Args:
1070      serialization_cache: A dictionary shared between all objects in the same
1071        object graph. This object is passed to both
1072        `_list_extra_dependencies_for_serialization` and
1073        `_list_functions_for_serialization`.
1074
1075    Returns:
1076      A dictionary mapping attribute names to trackable objects.
1077    """
1078    del serialization_cache
1079    return dict()
1080
1081  def _list_functions_for_serialization(self, serialization_cache):
1082    """Lists the functions of this trackable to serialize.
1083
1084    Internal sub-classes can override this with specific logic. E.g.
1085    `AutoTrackable` provides an implementation that returns the `attr`
1086    that return functions.
1087
1088    Args:
1089      serialization_cache: Dictionary passed to all objects in the same object
1090        graph during serialization.
1091
1092    Returns:
1093        A dictionary mapping attribute names to `Function` or
1094        `ConcreteFunction`.
1095    """
1096    del serialization_cache
1097    return dict()
1098
1099  def _map_resources(self, save_options):  # pylint: disable=unused-argument
1100    """Makes new resource handle ops corresponding to existing resource tensors.
1101
1102    Internal sub-classes can override this to inform model saving how to add new
1103    resource handle ops to the main GraphDef of a SavedModel (TF 1.x style
1104    graph), which allows session based APIs (e.g, C++ loader API) to interact
1105    with resources owned by this object.
1106
1107    Args:
1108      save_options: A tf.saved_model.SaveOptions instance.
1109
1110    Returns:
1111      A tuple of (object_map, resource_map):
1112        object_map: A dictionary mapping from objects that hold existing
1113          resource tensors to replacement objects created to hold the new
1114          resource tensors.
1115        resource_map: A dictionary mapping from existing resource tensors to
1116          newly created resource tensors.
1117    """
1118    return {}, {}
1119