• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""An object-local variable management scheme."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import collections
22
23import six
24
25from tensorflow.python.eager import context
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import gen_io_ops as io_ops
32from tensorflow.python.platform import tf_logging as logging
33from tensorflow.python.training.saving import saveable_object
34from tensorflow.python.util import tf_contextlib
35from tensorflow.python.util import tf_decorator
36
37# Key where the object graph proto is saved in a TensorBundle
38OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"
39
40# A key indicating a variable's value in an object's checkpointed Tensors
41# (Trackable._gather_saveables_for_checkpoint). If this is the only key and
42# the object has no dependencies, then its value may be restored on object
43# creation (avoiding double assignment when executing eagerly).
44VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
45OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"
46
TrackableReference = collections.namedtuple(
    "TrackableReference",
    [
        "name",  # The local name for this dependency.
        "ref",  # The Trackable object being referenced.
    ])
55
56
class CheckpointInitialValue(ops.Tensor):
  """Tensor wrapper that tags a restored initial value with its restore UID.

  When supplied as an initial value, an object of this type tells a `Variable`
  (`Variable`, `ResourceVariable`, etc.) which restore produced the value.
  Deferred restorations can then be sequenced in the order the user requested
  them, with assignment as a fallback when the initial value is not actually
  used (e.g. because a custom getter interfered).

  See comments in _add_variable_with_custom_getter for more information about
  how `CheckpointInitialValue` is used.
  """

  def __init__(self, checkpoint_position, shape=None):
    self._checkpoint_position = checkpoint_position
    self.wrapped_value = checkpoint_position.value_tensors()[VARIABLE_VALUE_KEY]
    if shape:
      # Attach static shape information to the initializer when available so
      # the resulting variable does not end up with an unknown shape.
      self.wrapped_value.set_shape(shape)

  def __getattr__(self, attr):
    # Delegate unknown attribute lookups to the wrapped restore tensor; fall
    # back to normal lookup (which raises AttributeError) if it lacks them too.
    try:
      return getattr(self.wrapped_value, attr)
    except AttributeError:
      return self.__getattribute__(attr)

  @property
  def checkpoint_position(self):
    return self._checkpoint_position
87
88
class NoRestoreSaveable(saveable_object.SaveableObject):
  """Embeds a tensor in a checkpoint with no restore ops."""

  def __init__(self, tensor, name, dtype=None, device=None):
    save_spec = saveable_object.SaveSpec(
        tensor, "", name, dtype=dtype, device=device)
    super(NoRestoreSaveable, self).__init__(tensor, [save_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    # Deliberately nothing to restore; return a no-op so callers still get a
    # valid operation back.
    return control_flow_ops.no_op()
99
100
@six.add_metaclass(abc.ABCMeta)
class PythonStateSaveable(saveable_object.SaveableObject):
  """An interface for saving/restoring volatile Python state."""

  @abc.abstractmethod
  def feed_dict_additions(self):
    """Fresh state to feed when running a graph.

    Returns:
      A dictionary mapping `Tensor`s to current Python state.
    """
    pass

  @abc.abstractmethod
  def freeze(self):
    """Build a `SaveableObject` that captures the current state as a constant.

    Used when executing eagerly to embed the current state as a constant, or
    when creating a static tf.compat.v1.train.Saver with the frozen current
    Python state.

    Returns:
      A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
      no Python state associated with it).
    """
    pass
127
128
class PythonStringStateSaveable(PythonStateSaveable):
  """Saves Python state in a checkpoint."""

  def __init__(self, name, state_callback, restore_callback=None):
    """Configure saving.

    Args:
      name: The checkpoint key this saveable writes to.
      state_callback: A zero-argument callable returning a string; invoked
        every time a checkpoint is written.
      restore_callback: Optional callable taking the restored Python string.
        When omitted, restoration does nothing and this value is ignored by
        status assertions such as assert_consumed().
    """
    self._has_trivial_state_callback = (restore_callback is None)
    self._restore_callback = restore_callback

    def _state_callback_wrapper():
      # Evaluate the user's callback lifted out of any active graph-building
      # scope via init_scope.
      with ops.init_scope():
        return state_callback()

    self._state_callback = _state_callback_wrapper
    with ops.device("/cpu:0"):
      self._save_string = constant_op.constant("", dtype=dtypes.string)
    save_spec = saveable_object.SaveSpec(
        self._save_string, "", name, dtype=dtypes.string)
    super(PythonStringStateSaveable, self).__init__(
        self._save_string, [save_spec], name)

  @property
  def optional_restore(self):
    """For values with no restore, relaxes assert_consumed()."""
    return self._has_trivial_state_callback

  def feed_dict_additions(self):
    """When running a graph, indicates fresh state to feed."""
    return {self._save_string: self._state_callback()}

  def freeze(self):
    """Create a frozen `SaveableObject` which saves the current state."""

    def _constant_state():
      return constant_op.constant(self._state_callback(), dtype=dtypes.string)

    return NoRestoreSaveable(
        tensor=_constant_state,
        dtype=dtypes.string,
        name=self.name,
        device="cpu:0")

  def python_restore(self, restored_strings):
    """Called to restore Python state."""
    if self._restore_callback:
      restored, = restored_strings
      self._restore_callback(restored)

  def restore(self, restored_tensors, restored_shapes):
    """Called to restore TensorFlow state (nothing to do)."""
    return control_flow_ops.no_op()
188
189
class CheckpointPosition(object):
  """Indicates a position within a `_CheckpointRestoreCoordinator`."""

  def __init__(self, checkpoint, proto_id):
    """Specify an object within a checkpoint.

    Args:
      checkpoint: A _CheckpointRestoreCoordinator object.
      proto_id: The index of this object in TrackableObjectGraph.nodes.
    """
    self._checkpoint = checkpoint
    self._proto_id = proto_id

  def restore(self, trackable):
    """Restore this value into `trackable`."""
    with ops.init_scope():
      if self.bind_object(trackable):
        # This object's correspondence with a checkpointed object is new, so
        # process deferred restorations for it and its dependencies.
        restore_ops = trackable._restore_from_checkpoint_position(self)  # pylint: disable=protected-access
        if restore_ops:
          self._checkpoint.new_restore_ops(restore_ops)

  def bind_object(self, trackable):
    """Set a checkpoint<->object correspondence and process slot variables.

    Args:
      trackable: The object to record a correspondence for.

    Returns:
      True if this is a new assignment, False if this object has already been
      mapped to a checkpointed `Object` proto.
    Raises:
      AssertionError: If another object is already bound to the `Object` proto.
    """
    checkpoint = self.checkpoint
    checkpoint.all_python_objects.add(trackable)
    current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
    checkpoint.matched_proto_ids.add(self._proto_id)
    if current_assignment is None:
      checkpoint.object_by_proto_id[self._proto_id] = trackable
      for deferred_slot_restoration in (
          checkpoint.deferred_slot_restorations.pop(self._proto_id, ())):
        trackable._create_or_restore_slot_variable(  # pylint: disable=protected-access
            slot_variable_position=CheckpointPosition(
                checkpoint=checkpoint,
                proto_id=deferred_slot_restoration.slot_variable_id),
            variable=deferred_slot_restoration.original_variable,
            slot_name=deferred_slot_restoration.slot_name)
      for slot_restoration in checkpoint.slot_restorations.pop(
          self._proto_id, ()):
        optimizer_object = checkpoint.object_by_proto_id.get(
            slot_restoration.optimizer_id, None)
        if optimizer_object is None:
          # The optimizer has not yet been created or tracked. Record in the
          # checkpoint that the slot variables need to be restored when it is.
          checkpoint.deferred_slot_restorations.setdefault(
              slot_restoration.optimizer_id, []).append(
                  _DeferredSlotVariableRestoration(
                      original_variable=trackable,
                      slot_variable_id=slot_restoration.slot_variable_id,
                      slot_name=slot_restoration.slot_name))
        else:
          optimizer_object._create_or_restore_slot_variable(  # pylint: disable=protected-access
              slot_variable_position=CheckpointPosition(
                  checkpoint=checkpoint,
                  proto_id=slot_restoration.slot_variable_id),
              variable=trackable,
              slot_name=slot_restoration.slot_name)
      return True  # New assignment
    else:
      # The object was already mapped for this checkpoint load, which means
      # we don't need to do anything besides check that the mapping is
      # consistent (if the dependency DAG is not a tree then there are
      # multiple paths to the same object).
      if current_assignment is not trackable:
        logging.warning((
            "Inconsistent references when loading the checkpoint into this "
            "object graph. Either the Trackable object references in the "
            "Python program have changed in an incompatible way, or the "
            "checkpoint was generated in an incompatible program.\n\nTwo "
            "checkpoint references resolved to different objects (%s and %s)."),
                        current_assignment, trackable)
      return False  # Not a new assignment

  def is_simple_variable(self):
    """Determine whether this value is restorable with a Tensor initializer."""
    attributes = self.object_proto.attributes
    return (len(attributes) == 1 and
            attributes[0].name == VARIABLE_VALUE_KEY and
            not self.object_proto.children)

  def value_tensors(self):
    """Create value `Tensor`s for this object's attributes.

    Does not require that the Python object has been created. Used for
    restore-on-create when executing eagerly.

    Returns:
      A dictionary mapping from object attribute names to `Tensor`s.
    """
    value_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      with ops.init_scope():
        with ops.device("/cpu:0"):
          # Run the restore itself on the CPU.
          value, = io_ops.restore_v2(
              prefix=self._checkpoint.save_path_tensor,
              tensor_names=[checkpoint_key],
              shape_and_slices=[""],
              dtypes=[base_type],
              name="%s_checkpoint_read" % (serialized_tensor.name,))
        # Copy the value to the current device if necessary.
        value_tensors[serialized_tensor.name] = array_ops.identity(value)
    # Bug fix: return outside the loop. Previously this returned from inside
    # the first iteration, so only the first attribute's value was produced.
    return value_tensors

  def gather_ops_or_named_saveables(self):
    """Looks up or creates SaveableObjects which don't have cached ops."""
    saveables = self.trackable._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
    # Name saveables based on the name this object had when it was checkpointed.
    named_saveables = {}
    python_saveables = []
    existing_restore_ops = []
    for serialized_tensor in self.object_proto.attributes:
      if context.executing_eagerly():
        existing_op = None
      else:
        existing_op = self._checkpoint.restore_ops_by_name.get(
            serialized_tensor.checkpoint_key, None)
      if existing_op is not None:
        existing_restore_ops.append(existing_op)
        continue

      # Only if we don't have cached ops for this SaveableObject, we'll see if
      # the SaveableObject itself has been cached. If not, we'll make it, and
      # either way we'll extract new ops from it (or if it has Python state to
      # restore, we'll run that).
      saveables_cache = self._checkpoint.graph_view.saveables_cache
      if saveables_cache is None:
        # No SaveableObject caching when executing eagerly.
        saveable = None
      else:
        # If we've already created and cached a SaveableObject for this
        # attribute, we can re-use it to avoid re-creating some ops when graph
        # building.
        saveable_list = saveables_cache.get(self.trackable,
                                            {}).get(serialized_tensor.name,
                                                    (None,))
        if len(saveable_list) == 1:
          # Almost every attribute will have exactly one SaveableObject.
          saveable, = saveable_list
        else:
          # Don't use cached SaveableObjects for partitioned variables, which is
          # the only case where we'd have a list of SaveableObjects. Op caching
          # will catch them.
          saveable = None
      if saveable is not None:
        # The name of this attribute has changed, so we need to re-generate
        # the SaveableObject.
        if serialized_tensor.checkpoint_key not in saveable.name:
          saveable = None
          del saveables_cache[self.trackable]
      if saveable is None:
        # If there was no cached SaveableObject, we should check if the Python
        # object has the attribute.
        saveable_factory = saveables.get(serialized_tensor.name, None)
        if saveable_factory is None:
          # Purposefully does not throw an exception if attributes have been
          # added or deleted. Stores unused attributes so an exception can be
          # raised if the user decides to check that everything in the
          # checkpoint was loaded.
          if not serialized_tensor.optional_restore:
            self._checkpoint.unused_attributes.setdefault(
                self._proto_id, []).append(serialized_tensor.name)
          continue
        if callable(saveable_factory):
          saveable = saveable_factory(name=serialized_tensor.checkpoint_key)
        else:
          saveable = saveable_factory
        if saveables_cache is not None:
          saveables_cache.setdefault(self.trackable,
                                     {})[serialized_tensor.name] = [saveable]
      if isinstance(saveable, PythonStateSaveable):
        python_saveables.append(saveable)
      else:
        named_saveables[serialized_tensor.checkpoint_key] = saveable
    return existing_restore_ops, named_saveables, python_saveables

  def restore_ops(self):
    """Create or fetch restore ops for this object's attributes.

    Requires that the `Trackable` Python object has been bound to an object
    ID in the checkpoint.

    Returns:
      A list of operations when graph building, or an empty list when executing
      eagerly.
    """
    (restore_ops, tensor_saveables,
     python_saveables) = self.gather_ops_or_named_saveables()
    restore_ops.extend(
        self._checkpoint.restore_saveables(tensor_saveables, python_saveables))
    return restore_ops

  @property
  def checkpoint(self):
    return self._checkpoint

  @property
  def trackable(self):
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def restore_uid(self):
    return self._checkpoint.restore_uid

  def __repr__(self):
    return repr(self.object_proto)
415
416
_DeferredSlotVariableRestoration = collections.namedtuple(
    "_DeferredSlotVariableRestoration",
    [
        "original_variable",  # The variable the slot belongs to.
        "slot_variable_id",  # Checkpoint proto id of the slot variable.
        "slot_name",  # The slot's local name.
    ])
423
_SlotVariableRestoration = collections.namedtuple(
    "_SlotVariableRestoration",
    [
        "optimizer_id",  # The checkpoint proto id of the optimizer object.
        "slot_variable_id",  # The checkpoint proto id of the slot variable.
        "slot_name",  # The slot's local name.
    ])
433
434
def no_automatic_dependency_tracking(method):
  """Decorator that suspends automatic dependency tracking inside `method`.

  Attribute assignments performed in the decorated method do not create
  checkpoint dependencies (Model respects this too). It is harmless on classes
  without automatic tracking, so it is safe to apply in base classes whose
  subclasses may also inherit from Trackable.

  Args:
    method: The method to decorate.

  Returns:
    A wrapped method which disables tracking for the duration of the call and
    then restores the previous setting (not thread safe).
  """

  def _method_wrapper(self, *args, **kwargs):
    saved_flag = getattr(self, "_self_setattr_tracking", True)
    self._self_setattr_tracking = False  # pylint: disable=protected-access
    try:
      return method(self, *args, **kwargs)
    finally:
      self._self_setattr_tracking = saved_flag  # pylint: disable=protected-access

  return tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)
463
464
@tf_contextlib.contextmanager
def no_manual_dependency_tracking_scope(obj):
  """Context manager that turns off manual dependency tracking on `obj`.

  Library methods sometimes record dependencies themselves; wrapping such a
  call in this scope suppresses that so the caller can perform its own
  tracking instead.

  For example:

  class TestLayer(tf.keras.Layer):
    def build():
      with no_manual_dependency_tracking_scope(self):
        var = self.add_variable("name1")  # Creates a var and doesn't track it
      self._track_trackable("name2", var)  # We track variable with name `name2`

  Args:
    obj: A trackable object.

  Yields:
    A scope in which `obj` does not manually track dependencies.
  """
  # pylint: disable=protected-access
  saved_flag = getattr(obj, "_manual_tracking", True)
  obj._manual_tracking = False
  try:
    yield
  finally:
    obj._manual_tracking = saved_flag
495
496
@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
  """Context manager that pauses automatic dependency tracking on `obj`.

  Objects inheriting from Autotrackable automatically create dependencies on
  trackable objects assigned to attributes and wrap assigned data structures
  (lists or dicts) with trackable classes. This scope temporarily suspends
  that behavior, much like the `no_automatic_dependency_tracking` decorator.

  Example usage:
  ```
  model = tf.keras.Model()
  model.arr1 = []  # Creates a ListWrapper object
  with no_automatic_dependency_tracking_scope(model):
    model.arr2 = []  # Creates a regular, untracked python list
  ```

  Args:
    obj: A trackable object.

  Yields:
    A scope in which attribute assignments on `obj` are not tracked.
  """
  # pylint: disable=protected-access
  saved_flag = getattr(obj, "_setattr_tracking", True)
  obj._setattr_tracking = False
  try:
    yield
  finally:
    obj._setattr_tracking = saved_flag
527
528
529class Trackable(object):
530  """Base class for `Trackable` objects without automatic dependencies.
531
532  This class has no __setattr__ override for performance reasons. Dependencies
533  must be added explicitly. Unless attribute assignment is performance-critical,
534  use `AutoTrackable` instead. Use `Trackable` for `isinstance`
535  checks.
536  """
537
538  # For compatibility with wrapt.ObjectProxy, attributes are all prefixed with
539  # _self_. We have some properties to forward semi-public attributes to their
540  # _self_ equivalents.
541
542  @property
543  def _setattr_tracking(self):
544    if not hasattr(self, "_self_setattr_tracking"):
545      self._self_setattr_tracking = True
546    return self._self_setattr_tracking
547
548  @_setattr_tracking.setter
549  def _setattr_tracking(self, value):
550    self._self_setattr_tracking = value
551
  @property
  def _update_uid(self):
    """UID of the highest assignment applied to this object.

    Forwards to a `_self_`-prefixed attribute for wrapt.ObjectProxy
    compatibility (see the class-level comment above); initialized to -1 in
    `_maybe_initialize_trackable`.
    """
    return self._self_update_uid

  @_update_uid.setter
  def _update_uid(self, value):
    self._self_update_uid = value
559
  @property
  def _unconditional_checkpoint_dependencies(self):
    # A list of TrackableReference namedtuples (see
    # _maybe_initialize_trackable); forwarded through a `_self_` attribute for
    # wrapt.ObjectProxy compatibility.
    return self._self_unconditional_checkpoint_dependencies
563
  @property
  def _unconditional_dependency_names(self):
    # Maps local dependency names -> Trackable objects; `_self_` forwarding
    # for wrapt.ObjectProxy compatibility.
    return self._self_unconditional_dependency_names
567
  @property
  def _name_based_restores(self):
    # The set of name-based restore coordinators this object has seen (see
    # _maybe_initialize_trackable); `_self_` forwarding for wrapt
    # compatibility.
    return self._self_name_based_restores
571
572  # Trackable does not do automatic dependency tracking, but uses the
573  # no_automatic_dependency_tracking decorator so it can avoid adding
574  # dependencies if a subclass is Trackable / inherits from Model (both of
575  # which have __setattr__ overrides).
576  @no_automatic_dependency_tracking
577  def _maybe_initialize_trackable(self):
578    """Initialize dependency management.
579
580    Not __init__, since most objects will forget to call it.
581    """
582    if hasattr(self, "_self_unconditional_checkpoint_dependencies"):
583      # __init__ already called. This check means that we don't need
584      # Trackable.__init__() in the constructor of every TensorFlow object.
585      return
586    # A list of TrackableReference objects. Some classes implementing
587    # `Trackable`, notably `Optimizer`s, may override the
588    # _checkpoint_dependencies property with conditional dependencies
589    # (e.g. based on the current graph when saving).
590    self._self_unconditional_checkpoint_dependencies = []
591    # Maps names -> Trackable objects
592    self._self_unconditional_dependency_names = {}
593    # Restorations for other Trackable objects on which this object may
594    # eventually depend. Maps local name -> CheckpointPosition list. Optimizers
595    # tack on conditional dependencies, and so need separate management of
596    # deferred dependencies too.
597    self._self_unconditional_deferred_dependencies = {}
598    # The UID of the highest assignment to this object. Used to ensure that the
599    # last requested assignment determines the final value of an object.
600    if hasattr(self, "_self_update_uid"):
601      raise AssertionError(
602          "Internal error: the object had an update UID set before its "
603          "initialization code was run.")
604    self._self_update_uid = -1
605    # When executing eagerly, holds a collection of _NameBasedRestoreCoordinator
606    # instances, which should be checked when creating variables or other
607    # saveables. These are passed on recursively to all dependencies, since
608    # unlike object-based checkpoint restores we don't know which subgraph is
609    # being restored in advance. This mechanism is only necessary for
610    # restore-on-create when executing eagerly, and so is unused when graph
611    # building.
612    self._self_name_based_restores = set()
613
614  @property
615  def _object_identifier(self):
616    """String used to identify this object in a SavedModel.
617
618    Generally, the object identifier is constant across objects of the same
619    class, while the metadata field is used for instance-specific data.
620
621    Returns:
622      String object identifier.
623    """
624    return "_generic_user_object"
625
  @property
  def _tracking_metadata(self):
    """String containing object metadata, which is saved to the SavedModel."""
    # Base objects carry no metadata; subclasses presumably override this
    # alongside `_object_identifier` — confirm in SavedModel serialization.
    return ""
630
  def _no_dependency(self, value):
    """If automatic dependency tracking is enabled, ignores `value`."""
    # Identity here; subclasses with automatic tracking presumably override
    # this to exempt `value` from tracking — confirm in AutoTrackable.
    return value
634
635  def _name_based_attribute_restore(self, checkpoint):
636    """Restore the object's attributes from a name-based checkpoint."""
637    self._self_name_based_restores.add(checkpoint)
638    if self._self_update_uid < checkpoint.restore_uid:
639      checkpoint.eager_restore(self)
640      self._self_update_uid = checkpoint.restore_uid
641
  @property
  def _checkpoint_dependencies(self):
    """All dependencies of this object.

    May be overridden to include conditional dependencies.

    Returns:
      A list of `TrackableReference` objects indicating named
      `Trackable` dependencies which should be saved along with this
      object.
    """
    # Base implementation exposes only the unconditional dependencies.
    return self._self_unconditional_checkpoint_dependencies
654
  @property
  def _deferred_dependencies(self):
    """A dictionary with deferred dependencies.

    Stores restorations for other Trackable objects on which this object
    may eventually depend. May be overridden by sub-classes (e.g. Optimizers use
    conditional dependencies based the current graph, and so need separate
    management of deferred dependencies too).

    Returns:
      A dictionary mapping from local name to a list of CheckpointPosition
      objects.
    """
    # Base implementation exposes only the unconditional deferred
    # dependencies.
    return self._self_unconditional_deferred_dependencies
669
670  def _lookup_dependency(self, name):
671    """Look up a dependency by name.
672
673    May be overridden to include conditional dependencies.
674
675    Args:
676      name: The local name of the dependency.
677
678    Returns:
679      A `Trackable` object, or `None` if no dependency by this name was
680      found.
681    """
682    return self._self_unconditional_dependency_names.get(name, None)
683
  def _add_variable_with_custom_getter(self,
                                       name,
                                       shape=None,
                                       dtype=dtypes.float32,
                                       initializer=None,
                                       getter=None,
                                       overwrite=False,
                                       **kwargs_for_getter):
    """Restore-on-create for a variable be saved with this `Trackable`.

    If the user has requested that this object or another `Trackable` which
    depends on this object be restored from a checkpoint (deferred loading
    before variable object creation), `initializer` may be ignored and the value
    from the checkpoint used instead.

    Args:
      name: A name for the variable. Must be unique within this object.
      shape: The shape of the variable.
      dtype: The data type of the variable.
      initializer: The initializer to use. Ignored if there is a deferred
        restoration left over from a call to
        `_restore_from_checkpoint_position`.
      getter: The getter to wrap which actually fetches the variable.
      overwrite: If True, disables unique name and type checks.
      **kwargs_for_getter: Passed to the getter.

    Returns:
      The new variable object.

    Raises:
      ValueError: If the variable name is not unique.
    """
    self._maybe_initialize_trackable()
    with ops.init_scope():
      if context.executing_eagerly():
        # If this is a variable with a single Tensor stored in the checkpoint,
        # we can set that value as an initializer rather than initializing and
        # then assigning (when executing eagerly). This call returns None if
        # there is nothing to restore.
        checkpoint_initializer = self._preload_simple_restoration(
            name=name, shape=shape)
      else:
        checkpoint_initializer = None
      if (checkpoint_initializer is not None and
          not (isinstance(initializer, CheckpointInitialValue) and
               (initializer.restore_uid > checkpoint_initializer.restore_uid))):
        # If multiple Trackable objects are "creating" the same variable
        # via the magic of custom getters, the one with the highest restore UID
        # (the one called last) has to make the final initializer. If another
        # custom getter interrupts this process by overwriting the initializer,
        # then we'll catch that when we call _track_trackable. So this is
        # "best effort" to set the initializer with the highest restore UID.
        initializer = checkpoint_initializer
        # The restored value (a CheckpointInitialValue) already carries static
        # shape information, so no shape is passed to the getter.
        shape = None
    new_variable = getter(
        name=name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        **kwargs_for_getter)

    # If we set an initializer and the variable processed it, tracking will not
    # assign again. It will add this variable to our dependencies, and if there
    # is a non-trivial restoration queued, it will handle that. This also
    # handles slot variables.
    if not overwrite or isinstance(new_variable, Trackable):
      return self._track_trackable(new_variable, name=name, overwrite=overwrite)
    else:
      # TODO(allenl): Some variable types are not yet supported. Remove this
      # fallback once all get_variable() return types are Trackable.
      return new_variable
755
756  def _preload_simple_restoration(self, name, shape):
757    """Return a dependency's value for restore-on-create.
758
759    Note the restoration is not deleted; if for some reason preload is called
760    and then not assigned to the variable (for example because a custom getter
761    overrides the initializer), the assignment will still happen once the
762    variable is tracked (determined based on checkpoint.restore_uid).
763
764    Args:
765      name: The object-local name of the dependency holding the variable's
766        value.
767      shape: The shape of the variable being loaded into.
768
769    Returns:
770      An callable for use as a variable's initializer/initial_value, or None if
771      one should not be set (either because there was no variable with this name
772      in the checkpoint or because it needs more complex deserialization). Any
773      non-trivial deserialization will happen when the variable object is
774      tracked.
775    """
776    deferred_dependencies_list = self._deferred_dependencies.get(name, ())
777    if not deferred_dependencies_list:
778      # Nothing to do; we don't have a restore for this dependency queued up.
779      return
780    for checkpoint_position in deferred_dependencies_list:
781      if not checkpoint_position.is_simple_variable():
782        # If _any_ pending restoration is too complicated to fit in an
783        # initializer (because it has dependencies, or because there are
784        # multiple Tensors to restore), bail and let the general tracking code
785        # handle it.
786        return None
787    checkpoint_position = max(
788        deferred_dependencies_list,
789        key=lambda restore: restore.checkpoint.restore_uid)
790    return CheckpointInitialValue(
791        checkpoint_position=checkpoint_position, shape=shape)
792
793  def _track_trackable(self, trackable, name, overwrite=False):
794    """Declare a dependency on another `Trackable` object.
795
796    Indicates that checkpoints for this object should include variables from
797    `trackable`.
798
799    Variables in a checkpoint are mapped to `Trackable`s based on the names
800    provided when the checkpoint was written. To avoid breaking existing
801    checkpoints when modifying a class, neither variable names nor dependency
802    names (the names passed to `_track_trackable`) may change.
803
804    Args:
805      trackable: A `Trackable` which this object depends on.
806      name: A local name for `trackable`, used for loading checkpoints into the
807        correct objects.
808      overwrite: Boolean, whether silently replacing dependencies is OK. Used
809        for __setattr__, where throwing an error on attribute reassignment would
810        be inappropriate.
811
812    Returns:
813      `trackable`, for convenience when declaring a dependency and
814      assigning to a member variable in one statement.
815
816    Raises:
817      TypeError: If `trackable` does not inherit from `Trackable`.
818      ValueError: If another object is already tracked by this name.
819    """
820    self._maybe_initialize_trackable()
821    if not isinstance(trackable, Trackable):
822      raise TypeError(("Trackable._track_trackable() passed type %s, not a "
823                       "Trackable.") % (type(trackable),))
824    if not getattr(self, "_manual_tracking", True):
825      return trackable
826    new_reference = TrackableReference(name=name, ref=trackable)
827    current_object = self._lookup_dependency(name)
828    if (current_object is not None and current_object is not trackable):
829      if not overwrite:
830        raise ValueError(
831            ("Called Trackable._track_trackable() with name='%s', "
832             "but a Trackable with this name is already declared as a "
833             "dependency. Names must be unique (or overwrite=True).") % (name,))
834      # This is a weird thing to do, but we're not going to stop people from
835      # using __setattr__.
836      for index, (old_name, _) in enumerate(
837          self._self_unconditional_checkpoint_dependencies):
838        if name == old_name:
839          self._self_unconditional_checkpoint_dependencies[
840              index] = new_reference
841    elif current_object is None:
842      self._self_unconditional_checkpoint_dependencies.append(new_reference)
843      self._handle_deferred_dependencies(name=name, trackable=trackable)
844    self._self_unconditional_dependency_names[name] = trackable
845    return trackable
846
847  def _handle_deferred_dependencies(self, name, trackable):
848    """Pop and load any deferred checkpoint restores into `trackable`.
849
850    This method does not add a new dependency on `trackable`, but it does
851    check if any outstanding/deferred dependencies have been queued waiting for
852    this dependency to be added (matched based on `name`). If so,
853    `trackable` and its dependencies are restored. The restorations are
854    considered fulfilled and so are deleted.
855
856    `_track_trackable` is more appropriate for adding a
857    normal/unconditional dependency, and includes handling for deferred
858    restorations. This method allows objects such as `Optimizer` to use the same
859    restoration logic while managing conditional dependencies themselves, by
860    overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the
861    object's dependencies based on the context it is saved/restored in (a single
862    optimizer instance can have state associated with multiple graphs).
863
864    Args:
865      name: The name of the dependency within this object (`self`), used to
866        match `trackable` with values saved in a checkpoint.
867      trackable: The Trackable object to restore (inheriting from `Trackable`).
868    """
869    self._maybe_initialize_trackable()
870    trackable._maybe_initialize_trackable()  # pylint: disable=protected-access
871    deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
872    for checkpoint_position in sorted(
873        deferred_dependencies_list,
874        key=lambda restore: restore.checkpoint.restore_uid,
875        reverse=True):
876      checkpoint_position.restore(trackable)
877
878    # Pass on any name-based restores queued in this object.
879    for name_based_restore in sorted(
880        self._self_name_based_restores,
881        key=lambda checkpoint: checkpoint.restore_uid,
882        reverse=True):
883      trackable._name_based_attribute_restore(name_based_restore)  # pylint: disable=protected-access
884
885  def _restore_from_checkpoint_position(self, checkpoint_position):
886    """Restore this object and its dependencies (may be deferred)."""
887    # Attempt a breadth-first traversal, since presumably the user has more
888    # control over shorter paths. If we don't have all of the dependencies at
889    # this point, the end result is not breadth-first (since other deferred
890    # traversals will happen later).
891    visit_queue = collections.deque([checkpoint_position])
892    restore_ops = []
893    tensor_saveables = {}
894    python_saveables = []
895    while visit_queue:
896      current_position = visit_queue.popleft()
897      new_restore_ops, new_tensor_saveables, new_python_saveables = (
898          current_position.trackable  # pylint: disable=protected-access
899          ._single_restoration_from_checkpoint_position(
900              checkpoint_position=current_position,
901              visit_queue=visit_queue))
902      restore_ops.extend(new_restore_ops)
903      tensor_saveables.update(new_tensor_saveables)
904      python_saveables.extend(new_python_saveables)
905    restore_ops.extend(
906        current_position.checkpoint.restore_saveables(
907            tensor_saveables, python_saveables))
908    return restore_ops
909
910  def _single_restoration_from_checkpoint_position(self, checkpoint_position,
911                                                   visit_queue):
912    """Restore this object, and either queue its dependencies or defer them."""
913    self._maybe_initialize_trackable()
914    checkpoint = checkpoint_position.checkpoint
915    # If the UID of this restore is lower than our current update UID, we don't
916    # need to actually restore the object. However, we should pass the
917    # restoration on to our dependencies.
918    if checkpoint.restore_uid > self._self_update_uid:
919      restore_ops, tensor_saveables, python_saveables = (
920          checkpoint_position.gather_ops_or_named_saveables())
921      self._self_update_uid = checkpoint.restore_uid
922    else:
923      restore_ops = ()
924      tensor_saveables = {}
925      python_saveables = ()
926    for child in checkpoint_position.object_proto.children:
927      child_position = CheckpointPosition(
928          checkpoint=checkpoint, proto_id=child.node_id)
929      local_object = self._lookup_dependency(child.local_name)
930      if local_object is None:
931        # We don't yet have a dependency registered with this name. Save it
932        # in case we do.
933        self._deferred_dependencies.setdefault(child.local_name,
934                                               []).append(child_position)
935      else:
936        if child_position.bind_object(trackable=local_object):
937          # This object's correspondence is new, so dependencies need to be
938          # visited. Delay doing it so that we get a breadth-first dependency
939          # resolution order (shallowest paths first). The caller is responsible
940          # for emptying visit_queue.
941          visit_queue.append(child_position)
942    return restore_ops, tensor_saveables, python_saveables
943
944  def _gather_saveables_for_checkpoint(self):
945    """Returns a dictionary of values to checkpoint with this object.
946
947    Keys in the returned dictionary are local to this object and in a separate
948    namespace from dependencies. Values may either be `SaveableObject` factories
949    or variables easily converted to `SaveableObject`s (as in
950    `tf.compat.v1.train.Saver`'s
951    `var_list` constructor argument).
952
953    `SaveableObjects` have a name set, which Trackable needs to generate
954    itself. So rather than returning `SaveableObjects` directly, this method
955    should return a dictionary of callables which take `name` arguments and
956    return `SaveableObjects` with that name.
957
958    If this object may also be passed to the global-name-based
959    `tf.compat.v1.train.Saver`,
960    the returned callables should have a default value for their name argument
961    (i.e. be callable with no arguments).
962
963    Returned values must be saved only by this object; if any value may be
964    shared, it should instead be a dependency. For example, variable objects
965    save their own values with the key `VARIABLE_VALUE_KEY`, but objects which
966    reference variables simply add a dependency.
967
968    Returns:
969      The dictionary mapping attribute names to `SaveableObject` factories
970      described above. For example:
971      {VARIABLE_VALUE_KEY:
972       lambda name="global_name_for_this_object":
973       SaveableObject(name=name, ...)}
974    """
975    return {}
976
977  def _list_extra_dependencies_for_serialization(self, serialization_cache):
978    """Lists extra dependencies to serialize.
979
980    Internal sub-classes can override this method to return extra dependencies
981    that should be saved with the object during SavedModel serialization. For
982    example, this is used to save `trainable_variables` in Keras models. The
983    python property `trainable_variables` contains logic to iterate through the
984    weights from the model and its sublayers. The serialized Keras model saves
985    `trainable_weights` as a trackable list of variables.
986
987    PLEASE NOTE when overriding this method:
988      1. This function may only generate new trackable objects the first time it
989         is called.
990      2. The returned dictionary must not have naming conflicts with
991         dependencies tracked by the root. In other words, if the root is
992         tracking `object_1` with name 'x', and this functions returns
993         `{'x': object_2}`, an error is raised when saving.
994
995    Args:
996      serialization_cache: A dictionary shared between all objects in the same
997        object graph. This object is passed to both
998        `_list_extra_dependencies_for_serialization` and
999        `_list_functions_for_serialization`.
1000
1001    Returns:
1002      A dictionary mapping attribute names to trackable objects.
1003    """
1004    del serialization_cache
1005    return dict()
1006
1007  def _list_functions_for_serialization(self, serialization_cache):
1008    """Lists the functions of this trackable to serialize.
1009
1010    Internal sub-classes can override this with specific logic. E.g.
1011    `AutoTrackable` provides an implementation that returns the `attr`
1012    that return functions.
1013
1014    Args:
1015      serialization_cache: Dictionary passed to all objects in the same object
1016        graph during serialization.
1017
1018    Returns:
1019        A dictionary mapping attribute names to `Function` or
1020        `ConcreteFunction`.
1021    """
1022    del serialization_cache
1023    return dict()
1024