• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""An object-local variable management scheme."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import collections
22import functools
23import json
24import weakref
25
26import six
27
28from tensorflow.python.eager import context
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import control_flow_ops
34from tensorflow.python.ops import gen_io_ops as io_ops
35from tensorflow.python.platform import tf_logging as logging
36from tensorflow.python.training.saving import saveable_object
37from tensorflow.python.util import nest
38from tensorflow.python.util import serialization
39from tensorflow.python.util import tf_decorator
40
41
# Key under which the serialized object graph proto is stored in a
# TensorBundle checkpoint.
OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"


# Attribute key holding a variable's value among an object's checkpointed
# Tensors (see Trackable._gather_saveables_for_checkpoint). When this is the
# object's only key and the object has no dependencies, the value can be
# restored while the object is being created, which avoids assigning twice
# when executing eagerly.
VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"

# One named edge in the dependency graph:
#   name: the local name of the dependency within its parent object.
#   ref: the Trackable object being referenced.
TrackableReference = collections.namedtuple(
    "TrackableReference", ("name", "ref"))
61
62
class CheckpointInitialValue(ops.Tensor):
  """Tensor wrapper for managing update UIDs in `Variables`.

  When supplied as an initial value, objects of this type let a `Variable`
  (`Variable`, `ResourceVariable`, etc.) know the UID of the restore the initial
  value came from. This allows deferred restorations to be sequenced in the
  order the user specified them, and lets us fall back on assignment if an
  initial value is not set (e.g. due to a custom getter interfering).

  Attribute lookups which fail on the wrapper itself are forwarded to the
  wrapped value `Tensor`.

  See comments in _add_variable_with_custom_getter for more information about
  how `CheckpointInitialValue` is used.
  """

  def __init__(self, checkpoint_position, shape=None):
    """Reads the checkpointed value for `checkpoint_position`.

    Args:
      checkpoint_position: A `CheckpointPosition` whose `value_tensors()`
        contains a `VARIABLE_VALUE_KEY` entry.
      shape: Optional static shape to set on the restored value.
    """
    self.wrapped_value = checkpoint_position.value_tensors()[
        VARIABLE_VALUE_KEY]
    if shape:
      # We need to set the static shape information on the initializer if
      # possible so we don't get a variable with an unknown shape.
      self.wrapped_value.set_shape(shape)
    self._checkpoint_position = checkpoint_position

  def __getattr__(self, attr):
    # `__getattr__` only runs after normal lookup fails. If `wrapped_value`
    # itself is missing (e.g. an instance created without `__init__`, as
    # `copy` does), forwarding through `self.wrapped_value` would re-enter this
    # method forever; fail with a plain AttributeError instead.
    if attr == "wrapped_value":
      raise AttributeError(attr)
    try:
      return getattr(self.wrapped_value, attr)
    except AttributeError:
      return self.__getattribute__(attr)

  @property
  def checkpoint_position(self):
    """The `CheckpointPosition` this initial value was read from."""
    return self._checkpoint_position

  @property
  def restore_uid(self):
    """The restore UID of the checkpoint this initial value was read from.

    Callers compare this between competing initial values to let the most
    recently requested restore win. Without this property the lookup was
    forwarded to the wrapped `Tensor`, which has no such attribute.
    """
    return self._checkpoint_position.restore_uid
94
95
class NoRestoreSaveable(saveable_object.SaveableObject):
  """A `SaveableObject` that writes `tensor` into a checkpoint but never
  restores anything from it."""

  def __init__(self, tensor, name, dtype=None):
    super(NoRestoreSaveable, self).__init__(
        tensor,
        [saveable_object.SaveSpec(tensor, "", name, dtype=dtype)],
        name)

  def restore(self, restored_tensors, restored_shapes):
    """Ignores the restored values; returns a no-op."""
    del restored_tensors, restored_shapes  # Unused.
    return control_flow_ops.no_op()
105
106
@six.add_metaclass(abc.ABCMeta)
class PythonStateSaveable(saveable_object.SaveableObject):
  """Abstract base for `SaveableObject`s that carry volatile Python state."""

  @abc.abstractmethod
  def feed_dict_additions(self):
    """Reports the fresh Python state to feed when running a graph.

    Returns:
      A dictionary mapping `Tensor`s to the current Python state.
    """

  @abc.abstractmethod
  def freeze(self):
    """Snapshots the current Python state into a new `SaveableObject`.

    Intended for eager execution, where the current state is embedded as a
    constant, and for building a static tf.train.Saver around the frozen
    current Python state.

    Returns:
      A `SaveableObject` that is not a `PythonStateSaveable` (i.e. it has no
      Python state attached to it).
    """
132
133
class PythonStringStateSaveable(PythonStateSaveable):
  """Checkpoints Python state that is serialized as a string."""

  def __init__(self, name, state_callback, restore_callback=None):
    """Sets up saving for a piece of Python state.

    Args:
      name: Checkpoint key the state is written under.
      state_callback: Zero-argument callable returning a string. It is invoked
        (under `ops.init_scope()`) every time a checkpoint is written.
      restore_callback: Optional callable receiving the restored Python string.
        When omitted, restoring does nothing and status assertions such as
        assert_consumed() ignore this value.
    """
    self._has_trivial_state_callback = restore_callback is None

    def _state_callback_wrapper():
      # Always evaluate the user's callback inside `ops.init_scope()`.
      with ops.init_scope():
        return state_callback()

    self._state_callback = _state_callback_wrapper
    self._restore_callback = restore_callback
    with ops.device("/cpu:0"):
      # Placeholder-like constant; the real value is fed at save time.
      self._save_string = constant_op.constant("", dtype=dtypes.string)
    save_spec = saveable_object.SaveSpec(
        self._save_string, "", name, dtype=dtypes.string)
    super(PythonStringStateSaveable, self).__init__(
        self._save_string, [save_spec], name)

  @property
  def optional_restore(self):
    """True when there is no restore callback; relaxes assert_consumed()."""
    return self._has_trivial_state_callback

  def feed_dict_additions(self):
    """Maps the save-string tensor to the freshly computed Python state."""
    return {self._save_string: self._state_callback()}

  def freeze(self):
    """Returns a `NoRestoreSaveable` that saves the state as of save time."""

    def _current_state_constant():
      return constant_op.constant(self._state_callback(), dtype=dtypes.string)

    return NoRestoreSaveable(
        tensor=_current_state_constant,
        name=self.name,
        dtype=dtypes.string)

  def python_restore(self, restored_strings):
    """Hands the single restored string to the restore callback, if any."""
    if not self._restore_callback:
      return
    (restored_string,) = restored_strings
    self._restore_callback(restored_string)

  def restore(self, restored_tensors, restored_shapes):
    """No TensorFlow state to restore; returns a no-op."""
    return control_flow_ops.no_op()
188
189
class CheckpointPosition(object):
  """Indicates a position within a `_CheckpointRestoreCoordinator`.

  A position pairs a checkpoint (restore coordinator) with the index of one
  node in its object graph proto. It binds Python objects to that node and
  restores the node's attributes into them.
  """

  def __init__(self, checkpoint, proto_id):
    """Specify an object within a checkpoint.

    Args:
      checkpoint: A _CheckpointRestoreCoordinator object.
      proto_id: The index of this object in TrackableObjectGraph.nodes.
    """
    self._checkpoint = checkpoint
    self._proto_id = proto_id

  def restore(self, trackable):
    """Restore this value into `trackable`.

    Only the first binding of an object to this position triggers a restore.
    Any restore ops produced are handed to the checkpoint via
    `new_restore_ops`.
    """
    with ops.init_scope():
      if self.bind_object(trackable):
        # This object's correspondence with a checkpointed object is new, so
        # process deferred restorations for it and its dependencies.
        restore_ops = trackable._restore_from_checkpoint_position(self)  # pylint: disable=protected-access
        if restore_ops:
          self._checkpoint.new_restore_ops(restore_ops)

  def bind_object(self, trackable):
    """Set a checkpoint<->object correspondence and process slot variables.

    `trackable` is always added to the checkpoint's `all_python_objects`.

    When no object is bound to this proto yet, `trackable` is bound and two
    kinds of slot variable restorations are processed:
      - restorations deferred until this proto's object (the optimizer) was
        bound; and
      - slot variables of this proto's object (the original variable). These
        are restored through their optimizer if it is already bound,
        otherwise deferred until it is.

    If a *different* object is already bound to the `Object` proto, a warning
    is logged and the existing binding is kept. No exception is raised.

    Args:
      trackable: The object to record a correspondence for.
    Returns:
      True if this is a new assignment, False if this object has already been
      mapped to a checkpointed `Object` proto.
    """
    checkpoint = self.checkpoint
    checkpoint.all_python_objects.add(trackable)
    current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
    if current_assignment is None:
      checkpoint.object_by_proto_id[self._proto_id] = trackable
      for deferred_slot_restoration in (
          checkpoint.deferred_slot_restorations.pop(self._proto_id, ())):
        trackable._create_or_restore_slot_variable(  # pylint: disable=protected-access
            slot_variable_position=CheckpointPosition(
                checkpoint=checkpoint,
                proto_id=deferred_slot_restoration.slot_variable_id),
            variable=deferred_slot_restoration.original_variable,
            slot_name=deferred_slot_restoration.slot_name)
      for slot_restoration in checkpoint.slot_restorations.pop(
          self._proto_id, ()):
        optimizer_object = checkpoint.object_by_proto_id.get(
            slot_restoration.optimizer_id, None)
        if optimizer_object is None:
          # The optimizer has not yet been created or tracked. Record in the
          # checkpoint that the slot variables need to be restored when it is.
          checkpoint.deferred_slot_restorations.setdefault(
              slot_restoration.optimizer_id, []).append(
                  _DeferredSlotVariableRestoration(
                      original_variable=trackable,
                      slot_variable_id=slot_restoration.slot_variable_id,
                      slot_name=slot_restoration.slot_name))
        else:
          optimizer_object._create_or_restore_slot_variable(  # pylint: disable=protected-access
              slot_variable_position=CheckpointPosition(
                  checkpoint=checkpoint,
                  proto_id=slot_restoration.slot_variable_id),
              variable=trackable,
              slot_name=slot_restoration.slot_name)
      return True  # New assignment
    # The object was already mapped for this checkpoint load, which means
    # we don't need to do anything besides check that the mapping is
    # consistent (if the dependency DAG is not a tree then there are
    # multiple paths to the same object).
    if current_assignment is not trackable:
      # Pass the objects as logging arguments so formatting is deferred to
      # the logging framework.
      logging.warning(
          ("Inconsistent references when loading the checkpoint into this "
           "object graph. Either the Trackable object references in the "
           "Python program have changed in an incompatible way, or the "
           "checkpoint was generated in an incompatible program.\n\nTwo "
           "checkpoint references resolved to different objects (%s and %s)."),
          current_assignment, trackable)
    return False  # Not a new assignment

  def is_simple_variable(self):
    """Determine whether this value is restorable with a Tensor initializer.

    True only for a proto with exactly one attribute, named
    `VARIABLE_VALUE_KEY`, and no children.
    """
    attributes = self.object_proto.attributes
    return (len(attributes) == 1
            and attributes[0].name == VARIABLE_VALUE_KEY
            and not self.object_proto.children)

  def value_tensors(self):
    """Create value `Tensor`s for this object's attributes.

    Does not require that the Python object has been created. Used for
    restore-on-create when executing eagerly.

    Returns:
      A dictionary mapping from object attribute names to `Tensor`s. It is
      empty when the object has no attributes.
    """
    value_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      with ops.init_scope():
        with ops.device("/cpu:0"):
          # Run the restore itself on the CPU.
          value, = io_ops.restore_v2(
              prefix=self._checkpoint.save_path_tensor,
              tensor_names=[checkpoint_key],
              shape_and_slices=[""],
              dtypes=[base_type],
              name="%s_checkpoint_read" % (serialized_tensor.name,))
        # Copy the value to the current device if necessary.
        value_tensors[serialized_tensor.name] = array_ops.identity(value)
    # Return only after the loop so every attribute is included.
    return value_tensors

  def _gather_ops_or_named_saveables(self):
    """Looks up or creates SaveableObjects which don't have cached ops.

    Returns:
      A tuple `(existing_restore_ops, named_saveables, python_saveables)`:
        existing_restore_ops: Restore ops already cached on the checkpoint for
          this object's attributes (never used when executing eagerly).
        named_saveables: A dictionary mapping checkpoint keys to
          `SaveableObject`s which still need restore ops.
        python_saveables: A list of `PythonStateSaveable`s whose state is
          restored in Python.
    """
    saveables = self.trackable._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
    # Name saveables based on the name this object had when it was checkpointed.
    named_saveables = {}
    python_saveables = []
    existing_restore_ops = []
    for serialized_tensor in self.object_proto.attributes:
      if context.executing_eagerly():
        existing_op = None
      else:
        existing_op = self._checkpoint.restore_ops_by_name.get(
            serialized_tensor.checkpoint_key, None)
      if existing_op is not None:
        existing_restore_ops.append(existing_op)
        continue

      # Only if we don't have cached ops for this SaveableObject, we'll see if
      # the SaveableObject itself has been cached. If not, we'll make it, and
      # either way we'll extract new ops from it (or if it has Python state to
      # restore, we'll run that).
      saveables_cache = self._checkpoint.graph_view.saveables_cache
      if saveables_cache is None:
        # No SaveableObject caching when executing eagerly.
        saveable = None
      else:
        # If we've already created and cached a SaveableObject for this
        # attribute, we can re-use it to avoid re-creating some ops when graph
        # building.
        saveable_list = saveables_cache.get(
            self.trackable, {}).get(serialized_tensor.name, (None,))
        if len(saveable_list) == 1:
          # Almost every attribute will have exactly one SaveableObject.
          saveable, = saveable_list
        else:
          # Don't use cached SaveableObjects for partitioned variables, which is
          # the only case where we'd have a list of SaveableObjects. Op caching
          # will catch them.
          saveable = None
      if (saveable is not None
          and serialized_tensor.checkpoint_key not in saveable.name):
        # The name of this attribute has changed, so the cached
        # SaveableObject is stale. Drop this object's cache entries and fall
        # through to re-generate it below. Leaving the loop here instead would
        # silently skip restoring this attribute and every remaining one.
        saveable = None
        saveables_cache.pop(self.trackable, None)
      if saveable is None:
        # If there was no cached SaveableObject, we should check if the Python
        # object has the attribute.
        saveable_factory = saveables.get(serialized_tensor.name, None)
        if saveable_factory is None:
          # Purposefully does not throw an exception if attributes have been
          # added or deleted. Stores unused attributes so an exception can be
          # raised if the user decides to check that everything in the
          # checkpoint was loaded.
          if not serialized_tensor.optional_restore:
            self._checkpoint.unused_attributes.setdefault(
                self.trackable, []).append(serialized_tensor.name)
          continue
        if callable(saveable_factory):
          saveable = saveable_factory(name=serialized_tensor.checkpoint_key)
        else:
          saveable = saveable_factory
        if saveables_cache is not None:
          saveables_cache.setdefault(
              self.trackable, {})[serialized_tensor.name] = [saveable]
      if isinstance(saveable, PythonStateSaveable):
        python_saveables.append(saveable)
      else:
        named_saveables[serialized_tensor.checkpoint_key] = saveable
    return existing_restore_ops, named_saveables, python_saveables

  def restore_ops(self):
    """Create or fetch restore ops for this object's attributes.

    Requires that the `Trackable` Python object has been bound to an object
    ID in the checkpoint.

    Returns:
      A list of operations when graph building, or an empty list when executing
      eagerly.
    """
    (restore_ops,
     tensor_saveables,
     python_saveables) = self._gather_ops_or_named_saveables()
    restore_ops.extend(self._checkpoint.restore_saveables(
        tensor_saveables, python_saveables))
    return restore_ops

  @property
  def checkpoint(self):
    """The restore coordinator this position belongs to."""
    return self._checkpoint

  @property
  def trackable(self):
    """The Python object bound to this position (KeyError if unbound)."""
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    """The node of the checkpoint's object graph proto at this position."""
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def restore_uid(self):
    """The UID of the restore this position's checkpoint belongs to."""
    return self._checkpoint.restore_uid

  def __repr__(self):
    return repr(self.object_proto)
414
415
# A slot variable restoration waiting for its optimizer to be bound; queued in
# the checkpoint's `deferred_slot_restorations` under the optimizer's proto id.
_DeferredSlotVariableRestoration = collections.namedtuple(
    "_DeferredSlotVariableRestoration",
    ("original_variable", "slot_variable_id", "slot_name"))

# A slot variable recorded in the checkpoint, keyed (in the checkpoint's
# `slot_restorations`) by the proto id of the variable that owns the slot.
_SlotVariableRestoration = collections.namedtuple(
    "_SlotVariableRestoration",
    (
        "optimizer_id",  # Checkpoint proto id of the optimizer object.
        "slot_variable_id",  # Checkpoint proto id of the slot variable.
        "slot_name",
    ))
434
435
def no_automatic_dependency_tracking(method):
  """Disables automatic dependency tracking on attribute assignment.

  Decorate a method of a Trackable object with this so that attribute
  assignments made while the method runs do not add dependencies (Model
  respects this too). It is harmless in classes without automatic dependency
  tracking, so base classes whose subclasses may also inherit from Trackable
  can use it safely.

  Args:
    method: The method to decorate.
  Returns:
    A decorated method that turns `_setattr_tracking` off on the receiving
    object for the duration of the call and then restores its previous value.
    This save/restore is not thread safe.
  """

  def _method_wrapper(self, *args, **kwargs):
    saved_tracking = getattr(self, "_setattr_tracking", True)
    self._setattr_tracking = False  # pylint: disable=protected-access
    try:
      return method(self, *args, **kwargs)
    finally:
      self._setattr_tracking = saved_tracking  # pylint: disable=protected-access

  return tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)
463
464
465class Trackable(object):
466  """Base class for `Trackable` objects without automatic dependencies.
467
468  This class has no __setattr__ override for performance reasons. Dependencies
469  must be added explicitly. Unless attribute assignment is performance-critical,
470  use `AutoTrackable` instead. Use `Trackable` for `isinstance`
471  checks.
472  """
473
474  # Trackable does not do automatic dependency tracking, but uses the
475  # no_automatic_dependency_tracking decorator so it can avoid adding
476  # dependencies if a subclass is Trackable / inherits from Model (both of
477  # which have __setattr__ overrides).
478  @no_automatic_dependency_tracking
479  def _maybe_initialize_trackable(self):
480    """Initialize dependency management.
481
482    Not __init__, since most objects will forget to call it.
483    """
484    if hasattr(self, "_unconditional_checkpoint_dependencies"):
485      # __init__ already called. This check means that we don't need
486      # Trackable.__init__() in the constructor of every TensorFlow object.
487      return
488    # A list of TrackableReference objects. Some classes implementing
489    # `Trackable`, notably `Optimizer`s, may override the
490    # _checkpoint_dependencies property with conditional dependencies
491    # (e.g. based on the current graph when saving).
492    self._unconditional_checkpoint_dependencies = []
493    # Maps names -> Trackable objects
494    self._unconditional_dependency_names = {}
495    # Restorations for other Trackable objects on which this object may
496    # eventually depend. Maps local name -> CheckpointPosition list. Optimizers
497    # tack on conditional dependencies, and so need separate management of
498    # deferred dependencies too.
499    self._unconditional_deferred_dependencies = {}
500    # The UID of the highest assignment to this object. Used to ensure that the
501    # last requested assignment determines the final value of an object.
502    if hasattr(self, "_update_uid"):
503      raise AssertionError(
504          "Internal error: the object had an update UID set before its "
505          "initialization code was run.")
506    self._update_uid = -1
507    # When executing eagerly, holds a collection of _NameBasedRestoreCoordinator
508    # instances, which should be checked when creating variables or other
509    # saveables. These are passed on recursively to all dependencies, since
510    # unlike object-based checkpoint restores we don't know which subgraph is
511    # being restored in advance. This mechanism is only necessary for
512    # restore-on-create when executing eagerly, and so is unused when graph
513    # building.
514    self._name_based_restores = set()
515
516  def _no_dependency(self, value):
517    """If automatic dependency tracking is enabled, ignores `value`."""
518    return value
519
520  def _name_based_attribute_restore(self, checkpoint):
521    """Restore the object's attributes from a name-based checkpoint."""
522    self._name_based_restores.add(checkpoint)
523    if self._update_uid < checkpoint.restore_uid:
524      checkpoint.eager_restore(self)
525      self._update_uid = checkpoint.restore_uid
526
527  @property
528  def _checkpoint_dependencies(self):
529    """All dependencies of this object.
530
531    May be overridden to include conditional dependencies.
532
533    Returns:
534      A list of `TrackableReference` objects indicating named
535      `Trackable` dependencies which should be saved along with this
536      object.
537    """
538    return self._unconditional_checkpoint_dependencies
539
540  @property
541  def _deferred_dependencies(self):
542    """A dictionary with deferred dependencies.
543
544    Stores restorations for other Trackable objects on which this object
545    may eventually depend. May be overridden by sub-classes (e.g. Optimizers use
546    conditional dependencies based the current graph, and so need separate
547    management of deferred dependencies too).
548
549    Returns:
550      A dictionary mapping from local name to a list of CheckpointPosition
551      objects.
552    """
553    return self._unconditional_deferred_dependencies
554
555  def _lookup_dependency(self, name):
556    """Look up a dependency by name.
557
558    May be overridden to include conditional dependencies.
559
560    Args:
561      name: The local name of the dependency.
562    Returns:
563      A `Trackable` object, or `None` if no dependency by this name was
564      found.
565    """
566    return self._unconditional_dependency_names.get(name, None)
567
568  def _add_variable_with_custom_getter(
569      self, name, shape=None, dtype=dtypes.float32,
570      initializer=None, getter=None, overwrite=False,
571      **kwargs_for_getter):
572    """Restore-on-create for a variable be saved with this `Trackable`.
573
574    If the user has requested that this object or another `Trackable` which
575    depends on this object be restored from a checkpoint (deferred loading
576    before variable object creation), `initializer` may be ignored and the value
577    from the checkpoint used instead.
578
579    Args:
580      name: A name for the variable. Must be unique within this object.
581      shape: The shape of the variable.
582      dtype: The data type of the variable.
583      initializer: The initializer to use. Ignored if there is a deferred
584        restoration left over from a call to
585        `_restore_from_checkpoint_position`.
586      getter: The getter to wrap which actually fetches the variable.
587      overwrite: If True, disables unique name and type checks.
588      **kwargs_for_getter: Passed to the getter.
589
590    Returns:
591      The new variable object.
592
593    Raises:
594      ValueError: If the variable name is not unique.
595    """
596    self._maybe_initialize_trackable()
597    with ops.init_scope():
598      if context.executing_eagerly():
599        # If this is a variable with a single Tensor stored in the checkpoint,
600        # we can set that value as an initializer rather than initializing and
601        # then assigning (when executing eagerly). This call returns None if
602        # there is nothing to restore.
603        checkpoint_initializer = self._preload_simple_restoration(
604            name=name, shape=shape)
605      else:
606        checkpoint_initializer = None
607      if (checkpoint_initializer is not None
608          and not (
609              isinstance(initializer, CheckpointInitialValue)
610              and (initializer.restore_uid
611                   > checkpoint_initializer.restore_uid))):
612        # If multiple Trackable objects are "creating" the same variable
613        # via the magic of custom getters, the one with the highest restore UID
614        # (the one called last) has to make the final initializer. If another
615        # custom getter interrupts this process by overwriting the initializer,
616        # then we'll catch that when we call _track_trackable. So this is
617        # "best effort" to set the initializer with the highest restore UID.
618        initializer = checkpoint_initializer
619        shape = None
620    new_variable = getter(
621        name=name, shape=shape, dtype=dtype, initializer=initializer,
622        **kwargs_for_getter)
623
624    # If we set an initializer and the variable processed it, tracking will not
625    # assign again. It will add this variable to our dependencies, and if there
626    # is a non-trivial restoration queued, it will handle that. This also
627    # handles slot variables.
628    if not overwrite or isinstance(new_variable, Trackable):
629      return self._track_trackable(new_variable, name=name,
630                                   overwrite=overwrite)
631    else:
632      # TODO(allenl): Some variable types are not yet supported. Remove this
633      # fallback once all get_variable() return types are Trackable.
634      return new_variable
635
636  def _preload_simple_restoration(self, name, shape):
637    """Return a dependency's value for restore-on-create.
638
639    Note the restoration is not deleted; if for some reason preload is called
640    and then not assigned to the variable (for example because a custom getter
641    overrides the initializer), the assignment will still happen once the
642    variable is tracked (determined based on checkpoint.restore_uid).
643
644    Args:
645      name: The object-local name of the dependency holding the variable's
646        value.
647      shape: The shape of the variable being loaded into.
648    Returns:
649      An callable for use as a variable's initializer/initial_value, or None if
650      one should not be set (either because there was no variable with this name
651      in the checkpoint or because it needs more complex deserialization). Any
652      non-trivial deserialization will happen when the variable object is
653      tracked.
654    """
655    deferred_dependencies_list = self._deferred_dependencies.get(name, ())
656    if not deferred_dependencies_list:
657      # Nothing to do; we don't have a restore for this dependency queued up.
658      return
659    for checkpoint_position in deferred_dependencies_list:
660      if not checkpoint_position.is_simple_variable():
661        # If _any_ pending restoration is too complicated to fit in an
662        # initializer (because it has dependencies, or because there are
663        # multiple Tensors to restore), bail and let the general tracking code
664        # handle it.
665        return None
666    checkpoint_position = max(
667        deferred_dependencies_list,
668        key=lambda restore: restore.checkpoint.restore_uid)
669    return CheckpointInitialValue(
670        checkpoint_position=checkpoint_position, shape=shape)
671
672  def _track_trackable(self, trackable, name, overwrite=False):
673    """Declare a dependency on another `Trackable` object.
674
675    Indicates that checkpoints for this object should include variables from
676    `trackable`.
677
678    Variables in a checkpoint are mapped to `Trackable`s based on the names
679    provided when the checkpoint was written. To avoid breaking existing
680    checkpoints when modifying a class, neither variable names nor dependency
681    names (the names passed to `_track_trackable`) may change.
682
683    Args:
684      trackable: A `Trackable` which this object depends on.
685      name: A local name for `trackable`, used for loading checkpoints into
686        the correct objects.
687      overwrite: Boolean, whether silently replacing dependencies is OK. Used
688        for __setattr__, where throwing an error on attribute reassignment would
689        be inappropriate.
690
691    Returns:
692      `trackable`, for convenience when declaring a dependency and
693      assigning to a member variable in one statement.
694
695    Raises:
696      TypeError: If `trackable` does not inherit from `Trackable`.
697      ValueError: If another object is already tracked by this name.
698    """
699    self._maybe_initialize_trackable()
700    if not isinstance(trackable, Trackable):
701      raise TypeError(
702          ("Trackable._track_trackable() passed type %s, not a "
703           "Trackable.") % (type(trackable),))
704    new_reference = TrackableReference(name=name, ref=trackable)
705    current_object = self._lookup_dependency(name)
706    if (current_object is not None
707        and current_object is not trackable):
708      if not overwrite:
709        raise ValueError(
710            ("Called Trackable._track_trackable() with name='%s', "
711             "but a Trackable with this name is already declared as a "
712             "dependency. Names must be unique (or overwrite=True).") % (name,))
713      # This is a weird thing to do, but we're not going to stop people from
714      # using __setattr__.
715      for index, (old_name, _) in enumerate(
716          self._unconditional_checkpoint_dependencies):
717        if name == old_name:
718          self._unconditional_checkpoint_dependencies[index] = new_reference
719    elif current_object is None:
720      self._unconditional_checkpoint_dependencies.append(new_reference)
721      self._handle_deferred_dependencies(
722          name=name, trackable=trackable)
723    self._unconditional_dependency_names[name] = trackable
724    return trackable
725
726  def _handle_deferred_dependencies(self, name, trackable):
727    """Pop and load any deferred checkpoint restores into `trackable`.
728
729    This method does not add a new dependency on `trackable`, but it does
730    check if any outstanding/deferred dependencies have been queued waiting for
731    this dependency to be added (matched based on `name`). If so,
732    `trackable` and its dependencies are restored. The restorations are
733    considered fulfilled and so are deleted.
734
735    `_track_trackable` is more appropriate for adding a
736    normal/unconditional dependency, and includes handling for deferred
737    restorations. This method allows objects such as `Optimizer` to use the same
738    restoration logic while managing conditional dependencies themselves, by
739    overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the
740    object's dependencies based on the context it is saved/restored in (a single
741    optimizer instance can have state associated with multiple graphs).
742
743    Args:
744      name: The name of the dependency within this object (`self`), used to
745        match `trackable` with values saved in a checkpoint.
746      trackable: The Trackable object to restore (inheriting from
747        `Trackable`).
748    """
749    self._maybe_initialize_trackable()
750    trackable._maybe_initialize_trackable()  # pylint: disable=protected-access
751    deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
752    for checkpoint_position in sorted(
753        deferred_dependencies_list,
754        key=lambda restore: restore.checkpoint.restore_uid,
755        reverse=True):
756      checkpoint_position.restore(trackable)
757
758    # Pass on any name-based restores queued in this object.
759    for name_based_restore in sorted(
760        self._name_based_restores,
761        key=lambda checkpoint: checkpoint.restore_uid,
762        reverse=True):
763      trackable._name_based_attribute_restore(name_based_restore)  # pylint: disable=protected-access
764
765  def _restore_from_checkpoint_position(self, checkpoint_position):
766    """Restore this object and its dependencies (may be deferred)."""
767    # Attempt a breadth-first traversal, since presumably the user has more
768    # control over shorter paths. If we don't have all of the dependencies at
769    # this point, the end result is not breadth-first (since other deferred
770    # traversals will happen later).
771    visit_queue = collections.deque([checkpoint_position])
772    restore_ops = []
773    while visit_queue:
774      current_position = visit_queue.popleft()
775      restore_ops.extend(nest.flatten(
776          current_position.trackable  # pylint: disable=protected-access
777          ._single_restoration_from_checkpoint_position(
778              checkpoint_position=current_position,
779              visit_queue=visit_queue)))
780    return restore_ops
781
782  def _single_restoration_from_checkpoint_position(
783      self, checkpoint_position, visit_queue):
784    """Restore this object, and either queue its dependencies or defer them."""
785    self._maybe_initialize_trackable()
786    checkpoint = checkpoint_position.checkpoint
787    # If the UID of this restore is lower than our current update UID, we don't
788    # need to actually restore the object. However, we should pass the
789    # restoration on to our dependencies.
790    if checkpoint.restore_uid > self._update_uid:
791      restore_ops = checkpoint_position.restore_ops()
792      self._update_uid = checkpoint.restore_uid
793    else:
794      restore_ops = ()
795    for child in checkpoint_position.object_proto.children:
796      child_position = CheckpointPosition(
797          checkpoint=checkpoint,
798          proto_id=child.node_id)
799      local_object = self._lookup_dependency(child.local_name)
800      if local_object is None:
801        # We don't yet have a dependency registered with this name. Save it
802        # in case we do.
803        self._deferred_dependencies.setdefault(child.local_name, []).append(
804            child_position)
805      else:
806        if child_position.bind_object(trackable=local_object):
807          # This object's correspondence is new, so dependencies need to be
808          # visited. Delay doing it so that we get a breadth-first dependency
809          # resolution order (shallowest paths first). The caller is responsible
810          # for emptying visit_queue.
811          visit_queue.append(child_position)
812    return restore_ops
813
814  def _gather_saveables_for_checkpoint(self):
815    """Returns a dictionary of values to checkpoint with this object.
816
817    Keys in the returned dictionary are local to this object and in a separate
818    namespace from dependencies. Values may either be `SaveableObject` factories
819    or variables easily converted to `SaveableObject`s (as in `tf.train.Saver`'s
820    `var_list` constructor argument).
821
822    `SaveableObjects` have a name set, which Trackable needs to generate
823    itself. So rather than returning `SaveableObjects` directly, this method
824    should return a dictionary of callables which take `name` arguments and
825    return `SaveableObjects` with that name.
826
827    If this object may also be passed to the global-name-based `tf.train.Saver`,
828    the returned callables should have a default value for their name argument
829    (i.e. be callable with no arguments).
830
831    Returned values must be saved only by this object; if any value may be
832    shared, it should instead be a dependency. For example, variable objects
833    save their own values with the key `VARIABLE_VALUE_KEY`, but objects which
834    reference variables simply add a dependency.
835
836    Returns:
837      The dictionary mapping attribute names to `SaveableObject` factories
838      described above. For example:
839      {VARIABLE_VALUE_KEY:
840       lambda name="global_name_for_this_object":
841       SaveableObject(name=name, ...)}
842    """
843    if not hasattr(self, "get_config"):
844      return {}
845    try:
846      self.get_config()
847    except NotImplementedError:
848      return {}
849    weak_self = weakref.ref(self)
850    def _state_callback():
851      """Serializes `self.get_config()` for saving."""
852      dereferenced_self = weak_self()
853      if dereferenced_self:
854        try:
855          return json.dumps(
856              dereferenced_self,
857              default=serialization.get_json_type,
858              sort_keys=True).encode("utf8")
859        except TypeError:
860          # Even if get_config worked objects may have produced garbage.
861          return ""
862      else:
863        return ""
864    return {OBJECT_CONFIG_JSON_KEY: functools.partial(
865        PythonStringStateSaveable,
866        state_callback=_state_callback)}
867
868  def _list_functions_for_serialization(self):
869    """Lists the functions of this trackable to serialize.
870
871    Internal sub-classes can override this with specific logic. E.g.
872    `AutoTrackable` provides an implementation that returns the `attr`
873    that return functions.
874
875    Returns:
876        A dictionary mapping attribute names to `Function` or
877        `ConcreteFunction`.
878    """
879    return dict()
880