• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Logic for restoring checkpointed values for Trackables."""
17import collections
19from tensorflow.python.checkpoint import saveable_compat
20from tensorflow.python.eager import context
21from tensorflow.python.framework import ops
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import gen_io_ops as io_ops
24from tensorflow.python.platform import tf_logging as logging
25from tensorflow.python.saved_model import registration
26from tensorflow.python.trackable import constants
27from tensorflow.python.trackable import python_state
28from tensorflow.python.trackable import trackable_utils
class CheckpointPosition(object):
  """Indicates a position within a `_CheckpointRestoreCoordinator`.

  A position pairs a restore coordinator with the index of one node in the
  checkpoint's `TrackableObjectGraph`. Once a Python object has been bound to
  that node (see `bind_object`), the position can gather or create the ops and
  saveables needed to restore that object and its dependencies.
  """

  __slots__ = ["_checkpoint", "_proto_id", "skip_restore"]

  def __init__(self, checkpoint, proto_id):
    """Specify an object within a checkpoint.

    Args:
      checkpoint: A _CheckpointRestoreCoordinator object.
      proto_id: The index of this object in TrackableObjectGraph.nodes.
    """
    self._checkpoint = checkpoint
    self._proto_id = proto_id
    # This may be set to True if the registered saver cannot be used with this
    # object.
    self.skip_restore = False

  def restore(self, trackable):
    """Restore this value into `trackable`."""
    with ops.init_scope():
      if self.bind_object(trackable):
        # This object's correspondence with a checkpointed object is new, so
        # process deferred restorations for it and its dependencies.
        restore_ops = self._restore_descendants()
        if restore_ops:
          self._checkpoint.new_restore_ops(restore_ops)

  def bind_object(self, trackable):
    """Set a checkpoint<->object correspondence.

    If a different object is already bound to this `Object` proto, the
    existing binding is kept and a warning is logged; no exception is raised.

    Args:
      trackable: The object to record a correspondence for.

    Returns:
      True if this is a new assignment, False if this object has already been
      mapped to a checkpointed `Object` proto.
    """
    checkpoint = self.checkpoint
    checkpoint.all_python_objects.add(trackable)
    current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
    checkpoint.matched_proto_ids.add(self._proto_id)
    if current_assignment is None:
      checkpoint.object_by_proto_id[self._proto_id] = trackable
      return True  # New assignment
    else:
      # The object was already mapped for this checkpoint load, which means
      # we don't need to do anything besides check that the mapping is
      # consistent (if the dependency DAG is not a tree then there are
      # multiple paths to the same object).
      if current_assignment is not trackable:
        logging.warning(
            "Inconsistent references when loading the checkpoint into this "
            "object graph. For example, in the saved checkpoint object, "
            "`model.layer.weight` and `model.layer_copy.weight` reference the "
            "same variable, while in the current object these are two different"
            " variables. The referenced variables are:"
            f"({current_assignment} and {trackable}).")
      return False  # Not a new assignment

  def is_simple_variable(self):
    """Determine whether this value is restorable with a Tensor initializer."""
    attributes = self.object_proto.attributes
    return (len(attributes) == 1 and
            attributes[0].name == constants.VARIABLE_VALUE_KEY and
            not self.object_proto.children)

  def value_tensors(self, shape_and_slices=None):
    """Create value `Tensor`s for this object's attributes.

    Does not require that the Python object has been created. Used for
    restore-on-create when executing eagerly.

    Args:
      shape_and_slices: A dict mapping from object attribute names to a shape
        and slice string that will be passed to a RestoreV2 op. If the dict is
        None or if an object attribute is not in the dict, the full tensor will
        be restored.

    Returns:
      A dictionary mapping from object attribute names to `Tensor`s.
    """
    value_tensors = {}
    # The IO device is the same for every attribute, so resolve it once.
    io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      with ops.init_scope():
        with ops.device(io_device):
          # Run the restore itself on the io_device(CPU or specified).
          if (shape_and_slices is not None and
              serialized_tensor.name in shape_and_slices):
            shape_and_slice = shape_and_slices[serialized_tensor.name]
          else:
            shape_and_slice = ""
          value, = io_ops.restore_v2(
              prefix=self._checkpoint.save_path_tensor,
              tensor_names=[checkpoint_key],
              shape_and_slices=[shape_and_slice],
              dtypes=[base_type],
              name="%s_checkpoint_read" % (serialized_tensor.name,))
        # Copy the value to the current device if necessary.
        value_tensors[serialized_tensor.name] = array_ops.identity(value)
    return value_tensors

  def gather_ops_or_named_saveables(self):
    """Looks up or creates SaveableObjects which don't have cached ops.

    Returns:
      A tuple of (
          existing_restore_ops: list,
          named_saveables: dict,
          python_positions: list,
          registered_savers: dict)
    """
    # pylint:disable=g-import-not-at-top
    # There are circular dependencies between Trackable and SaveableObject,
    # so we must import it here.
    # TODO(b/224069573): Remove this code from Trackable.
    from tensorflow.python.training.saving import saveable_object_util
    # pylint:enable=g-import-not-at-top

    recorded_registered_saver = self.get_registered_saver_name()
    if not (self.object_proto.attributes or recorded_registered_saver):
      # Nothing was checkpointed for this node.
      return [], {}, [], {}

    existing_restore_ops = []
    named_saveables = {}
    python_positions = []
    registered_savers = collections.defaultdict(dict)

    saveable_factories = saveable_object_util.saveable_objects_from_trackable(
        self.trackable)
    saver_name = registration.get_registered_saver_name(self.trackable)

    if recorded_registered_saver:
      if not self.skip_restore:
        name = self.object_proto.registered_saver.object_name
        registered_savers[recorded_registered_saver][name] = self.trackable
      # Else: Skip restoration of this Trackable. This skip only happens when
      # `registration.get_strict_predicate_restore` is false for the saver.
      # Otherwise, an error would have been raised at
      # `self.get_registered_saver_name()`.
    elif saver_name:
      # In this case, the checkpoint has a recorded serialized tensor but no
      # registered saver, while the Trackable loading the checkpoint has
      # migrated to the registered checkpoint functionality (TPUEmbedding is an
      # example of this).

      # Set the Trackable's object name to the first checkpoint key that is
      # stored in checkpoint. If there is a use case that requires the other
      # keys, then we can take another look at this.
      registered_savers[saver_name] = {
          self.object_proto.attributes[0].checkpoint_key: self.trackable
      }
    elif isinstance(self.trackable, python_state.PythonState):
      python_positions.append(self)
    elif saveable_factories.keys() == {
        trackable_utils.SERIALIZE_TO_TENSORS_NAME
    }:
      existing_restore_ops, named_saveables = (
          self._create_serialize_to_tensor_saveable(saveable_factories))
    elif saveable_factories:
      existing_restore_ops, named_saveables = (
          self._create_saveables_by_attribute_name(saveable_factories))
    else:
      # If no registered savers were found, then it means that one or more
      # serialized tensors were never used.
      for serialized_tensor in self.object_proto.attributes:
        self._checkpoint.unused_attributes.setdefault(
            self._proto_id, []).append(serialized_tensor.name)
    return (existing_restore_ops, named_saveables, python_positions,
            registered_savers)

  def _create_serialize_to_tensor_saveable(self, saveable_factories):
    """Creates a saveable using the _serialize_to_tensor method.

    Args:
      saveable_factories: a dict whose
        `trackable_utils.SERIALIZE_TO_TENSORS_NAME` entry is a factory
        accepting a `name` keyword argument.

    Returns:
      A tuple of (
          existing_restore_ops: list,
          named_saveables: dict)
    """
    # Extract the saveable name from the checkpoint key. This will be used as
    # the cache key or the name to pass to the saveable factory.
    suffix = saveable_compat.get_saveable_name(self.trackable) or ""
    saveable_name = _extract_saveable_name(
        self.object_proto.attributes[0].checkpoint_key) + suffix

    # Try to find the cached saveable (only in graph mode).
    if not context.executing_eagerly():
      existing_op = self._checkpoint.restore_ops_by_name.get(
          saveable_name, None)
      if existing_op is not None:
        # Callers `extend` their op lists with this value, so the cached op
        # must be wrapped in a list rather than returned bare.
        return [existing_op], {}

      saveables_cache = self._checkpoint.saveables_cache.setdefault(
          self.trackable, {})
      if saveable_name in saveables_cache:
        return [], {saveable_name: saveables_cache[saveable_name]}

    saveable = saveable_factories[trackable_utils.SERIALIZE_TO_TENSORS_NAME](
        name=saveable_name)
    if not context.executing_eagerly():
      saveables_cache[saveable_name] = saveable
    return [], {saveable_name: saveable}

  def _create_saveables_by_attribute_name(self, saveable_factories):
    """Creates or caches SaveableObjects by matching the attribute names.

    The attribute name keys in the `saveable_factories` is used to find the
    corresponding attribute in the object proto. Attributes contain checkpoint
    keys which are passed to the factory function to generate the
    SaveableObject.

    Args:
      saveable_factories: a dict mapping attribute name to a callable factory
        function that produces a SaveableObject.

    Returns:
      A tuple of (
          existing_restore_ops: list,
          named_saveables: dict)
    """
    # Name saveables based on the name this object had when it was checkpointed.
    named_saveables = {}
    existing_restore_ops = []

    # Forward compatibility code: when loading a future checkpoint, there may
    # be multiple SerializedTensors mapped to a single saveable.
    created_compat_names = set()

    for serialized_tensor in self.object_proto.attributes:
      if context.executing_eagerly():
        existing_op = None
      else:
        existing_op = self._checkpoint.restore_ops_by_name.get(
            serialized_tensor.checkpoint_key, None)
      if existing_op is not None:
        existing_restore_ops.append(existing_op)
        continue

      if any(serialized_tensor.name.startswith(name)
             for name in created_compat_names):
        continue  # Saveable has already been created for this tensor.

      # Only if we don't have cached ops for this SaveableObject, we'll see if
      # the SaveableObject itself has been cached. If not, we'll make it, and
      # either way we'll extract new ops from it (or if it has Python state to
      # restore, we'll run that).
      saveables_cache = self._checkpoint.saveables_cache
      if saveables_cache is None:
        # No SaveableObject caching when executing eagerly.
        saveable = None
      else:
        # If we've already created and cached a SaveableObject for this
        # attribute, we can re-use it to avoid re-creating some ops when graph
        # building.
        saveable_list = saveables_cache.get(self.trackable,
                                            {}).get(serialized_tensor.name,
                                                    (None,))
        if len(saveable_list) == 1:
          # Almost every attribute will have exactly one SaveableObject.
          saveable, = saveable_list
        else:
          # Don't use cached SaveableObjects for partitioned variables, which is
          # the only case where we'd have a list of SaveableObjects. Op caching
          # will catch them.
          saveable = None
      if saveable is not None:
        # The name of this attribute has changed, so we need to re-generate
        # the SaveableObject.
        if serialized_tensor.checkpoint_key not in saveable.name:
          saveable = None
          del saveables_cache[self.trackable]
      if saveable is None:
        # If there was no cached SaveableObject, create one.
        # Use the name to check if the Python object has the same attribute.
        saveable = _get_saveable_from_factory(saveable_factories,
                                              serialized_tensor,
                                              created_compat_names)
        if saveable is None:
          # Purposefully does not throw an exception if attributes have been
          # added or deleted. Stores unused attributes so an exception can be
          # raised if the user decides to check that everything in the
          # checkpoint was loaded.
          self._checkpoint.unused_attributes.setdefault(
              self._proto_id, []).append(serialized_tensor.name)
          continue
        if saveables_cache is not None:
          saveables_cache.setdefault(self.trackable,
                                     {})[serialized_tensor.name] = [saveable]
      named_saveables[serialized_tensor.checkpoint_key] = saveable

    return existing_restore_ops, named_saveables

  def restore_ops(self):
    """Create or fetch restore ops for this object's attributes.

    Requires that the `Trackable` Python object has been bound to an object
    ID in the checkpoint.

    Returns:
      A list of operations when graph building, or an empty list when executing
      eagerly.

    Raises:
      ValueError: If the checkpointed object has a registered saver.
    """
    if self._has_registered_saver():
      raise ValueError("Unable to run individual checkpoint restore for objects"
                       " with registered savers.")
    (restore_ops, tensor_saveables, python_positions,
     _) = self.gather_ops_or_named_saveables()
    restore_ops.extend(
        self._checkpoint.restore_saveables(tensor_saveables, python_positions))
    return restore_ops

  @property
  def checkpoint(self):
    """The restore coordinator this position belongs to."""
    return self._checkpoint

  @property
  def trackable(self):
    """The Python object currently bound to this position's proto id."""
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    """The object graph node at this position's proto id."""
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def proto_id(self):
    """The index of this object in the object graph's nodes."""
    return self._proto_id

  @property
  def restore_uid(self):
    """The `restore_uid` of the underlying checkpoint coordinator."""
    return self._checkpoint.restore_uid

  def __repr__(self):
    return repr(self.object_proto)

  def value_shape(self):
    """The shape of the VARIABLE_VALUE tensor.

    Returns:
      If found a TensorShape object, otherwise None.
    """
    for serialized_tensor in self.object_proto.attributes:
      if serialized_tensor.name == constants.VARIABLE_VALUE_KEY:
        return self._checkpoint.shape_map[serialized_tensor.checkpoint_key]
    return None

  def _has_registered_saver(self):
    """Whether the checkpointed object recorded a registered saver name."""
    return bool(self.object_proto.registered_saver.name)

  def get_registered_saver_name(self):
    """Returns the registered saver name defined in the Checkpoint.

    If the bound trackable fails validation against the recorded saver and
    strict predicate restore is disabled for it, `skip_restore` is set to True
    instead of raising.

    Returns:
      The recorded saver name, or None if the checkpointed object has no
      registered saver.

    Raises:
      ValueError: If validation fails and strict predicate restore is enabled
        for the saver.
    """
    if self._has_registered_saver():
      saver_name = self.object_proto.registered_saver.name
      try:
        registration.validate_restore_function(self.trackable, saver_name)
      except ValueError:
        if registration.get_strict_predicate_restore(saver_name):
          raise
        self.skip_restore = True
      return saver_name
    return None

  def create_slot_variable_position(self, optimizer_object, variable,
                                    slot_variable_id, slot_name):
    """Generates CheckpointPosition for a slot variable.

    Args:
      optimizer_object: Optimizer that owns the slot variable.
      variable: Variable associated with the slot variable.
      slot_variable_id: ID of the slot variable.
      slot_name: Name of the slot variable.

    Returns:
      If there is a slot variable in the `optimizer_object` that has not been
      bound to the checkpoint, this function returns a tuple of (
        new `CheckpointPosition` for the slot variable,
        the slot variable itself).
      Otherwise returns `(None, None)`.
    """
    slot_variable_position = CheckpointPosition(
        checkpoint=self.checkpoint, proto_id=slot_variable_id)
    # pylint: disable=protected-access
    slot_variable = optimizer_object._create_or_restore_slot_variable(
        slot_variable_position=slot_variable_position,
        variable=variable,
        slot_name=slot_name)
    # pylint: enable=protected-access
    if (slot_variable is not None and
        slot_variable_position.bind_object(slot_variable)):
      return slot_variable_position, slot_variable
    else:
      return None, None

  def create_child_position(self, node_id):
    """Returns a position for `node_id` in the same checkpoint."""
    return CheckpointPosition(checkpoint=self.checkpoint, proto_id=node_id)

  def _restore_descendants(self):
    """Restore the bound Trackable and dependencies (may be deferred)."""
    # Attempt a breadth-first traversal, since presumably the user has more
    # control over shorter paths. If we don't have all of the dependencies at
    # this point, the end result is not breadth-first (since other deferred
    # traversals will happen later).

    # You may be wondering why elements in the `visit_queue` are tuples that
    # contains both CheckpointPositions and their Trackable. The reason is that
    # Optimizers will not keep a strong reference to slot vars for
    # ShardedVariables. The slot variable must be kept in memory until the
    # restore saveables have been created.
    visit_queue = collections.deque([(self, self.trackable)])
    restore_ops = []
    tensor_saveables = {}
    python_positions = []
    registered_savers = collections.defaultdict(dict)
    while visit_queue:
      current_position, _ = visit_queue.popleft()

      # Restore using the ops defined in a Saveable or registered function.
      (new_restore_ops, new_tensor_saveables, new_python_positions,
       new_registered_savers) = current_position._single_restore()  # pylint: disable=protected-access
      restore_ops.extend(new_restore_ops)
      tensor_saveables.update(new_tensor_saveables)
      python_positions.extend(new_python_positions)
      for saver_name, trackable_map in new_registered_savers.items():
        registered_savers[saver_name].update(trackable_map)

      # Pass the restoration to the dependencies.
      _queue_children_for_restoration(current_position, visit_queue)
      _queue_slot_variables(current_position, visit_queue)

    restore_ops.extend(
        current_position.checkpoint.restore_saveables(tensor_saveables,
                                                      python_positions,
                                                      registered_savers))
    return restore_ops

  def _single_restore(self):
    """Restores the trackable.

    Returns:
      A tuple of (restore_ops, tensor_saveables, python_positions,
      registered_savers). All four are empty when the trackable's
      `_update_uid` is not lower than this checkpoint's `restore_uid`.
    """
    trackable = self.trackable
    trackable._maybe_initialize_trackable()  # pylint: disable=protected-access
    checkpoint = self.checkpoint
    # If the UID of this restore is lower than our current update UID, we don't
    # need to actually restore the object.
    if checkpoint.restore_uid > trackable._update_uid:  # pylint: disable=protected-access
      restore_ops, tensor_saveables, python_positions, registered_savers = (
          self.gather_ops_or_named_saveables())
      trackable._update_uid = checkpoint.restore_uid  # pylint: disable=protected-access
    else:
      restore_ops = ()
      tensor_saveables = {}
      python_positions = ()
      registered_savers = {}
    return restore_ops, tensor_saveables, python_positions, registered_savers
def _queue_children_for_restoration(checkpoint_position, visit_queue):
  """Queues the restoration of trackable's children or defers them.

  For every child recorded in the checkpointed object, looks up the matching
  dependency on the bound trackable. Children that exist and are newly bound
  are appended to `visit_queue`; children that do not exist yet are stored in
  the trackable's `_deferred_dependencies` when the checkpoint holds something
  for them.

  Args:
    checkpoint_position: The position whose children should be processed.
    visit_queue: A queue of `(position, trackable)` tuples, appended to in
      place.
  """
  # pylint: disable=protected-access
  parent = checkpoint_position.trackable
  for child_ref in checkpoint_position.object_proto.children:
    position = checkpoint_position.create_child_position(child_ref.node_id)
    dependency = parent._lookup_dependency(child_ref.local_name)
    node = position.object_proto

    if dependency is not None:
      # A new correspondence means the child's own dependencies still need to
      # be visited. Queue it rather than recursing so resolution stays
      # breadth-first (shallowest paths first); the caller drains the queue.
      if position.bind_object(trackable=dependency):
        visit_queue.append((position, dependency))
      continue

    # No dependency with this name is registered yet. Remember the position
    # in case one is added later, but only if the checkpoint has something
    # for it.
    if node.HasField("has_checkpoint_values"):
      worth_deferring = node.has_checkpoint_values.value
    else:
      # Field not set: fall back to checking whether the node has children
      # and/or checkpointed values.
      worth_deferring = bool(
          node.children or node.attributes or node.slot_variables or
          node.HasField("registered_saver"))
    if worth_deferring:
      pending = parent._deferred_dependencies.setdefault(
          child_ref.local_name, [])
      pending.append(position)
# Record of a slot variable restoration that has to wait until its optimizer
# object is bound to the checkpoint.
_DeferredSlotVariableRestoration = collections.namedtuple(
    "_DeferredSlotVariableRestoration",
    ("original_variable", "slot_variable_id", "slot_name"),
)
def _queue_slot_variables(checkpoint_position, visit_queue):
  """Queues slot variables for restoration.

  Handles two sets of pending work keyed by this position's proto id: slot
  restorations that were deferred until this object (as the optimizer) was
  bound, and slot restorations recorded for this object as the original
  variable.

  Args:
    checkpoint_position: The position whose slot variables should be queued.
    visit_queue: A queue of `(position, trackable)` tuples, appended to in
      place.
  """
  owner = checkpoint_position.trackable
  coordinator = checkpoint_position.checkpoint

  # Restorations that were waiting for this object to appear as the optimizer.
  deferred = coordinator.deferred_slot_restorations.pop(
      checkpoint_position.proto_id, ())
  for pending in deferred:
    position, slot_var = checkpoint_position.create_slot_variable_position(
        owner, pending.original_variable, pending.slot_variable_id,
        pending.slot_name)
    if position is not None:
      visit_queue.append((position, slot_var))

  # Restorations where this object is the variable the slot belongs to.
  for restoration in coordinator.slot_restorations.pop(
      checkpoint_position.proto_id, ()):
    optimizer = coordinator.object_by_proto_id.get(
        restoration.optimizer_id, None)
    if optimizer is None:
      # The optimizer has not yet been created or tracked. Record in the
      # checkpoint that the slot variables need to be restored when it is.
      waiting = coordinator.deferred_slot_restorations.setdefault(
          restoration.optimizer_id, [])
      waiting.append(
          _DeferredSlotVariableRestoration(
              original_variable=owner,
              slot_variable_id=restoration.slot_variable_id,
              slot_name=restoration.slot_name))
      continue

    # `optimizer` can be a `Checkpoint` when user only needs the attributes
    # the optimizer holds, such as `iterations`. In those cases, it would not
    # have the optimizer's `_create_or_restore_slot_variable` method.
    if not hasattr(optimizer, "_create_or_restore_slot_variable"):
      continue

    position, slot_var = checkpoint_position.create_slot_variable_position(
        optimizer, owner, restoration.slot_variable_id,
        restoration.slot_name)
    if position is not None:
      visit_queue.append((position, slot_var))
def _extract_saveable_name(checkpoint_key):
  """Returns `checkpoint_key` truncated just after the attributes marker.

  The result is the prefix of the key up to and including the first
  "{OBJECT_ATTRIBUTES_NAME}/" substring.

  Args:
    checkpoint_key: A checkpoint key string containing the marker.

  Returns:
    The leading part of `checkpoint_key`, ending with the marker.

  Raises:
    ValueError: If the marker does not occur in `checkpoint_key`.
  """
  marker = trackable_utils.OBJECT_ATTRIBUTES_NAME + "/"
  prefix_end = checkpoint_key.index(marker) + len(marker)
  return checkpoint_key[:prefix_end]
def _get_saveable_from_factory(saveable_factories, serialized_tensor,
                               created_compat_names):
  """Returns the saveable generated from the factory method.

  Args:
    saveable_factories: a dict mapping attribute name to a factory (usually a
      callable accepting a `name` keyword argument).
    serialized_tensor: the checkpointed attribute; its `name` selects the
      factory and its `checkpoint_key` is used to build the factory input.
    created_compat_names: a set that is updated in place with the factory name
      whenever a factory is matched by prefix rather than by exact name.

  Returns:
    The result of calling the matched factory, the matched entry itself if it
    is not callable, or None if nothing matched.

  Raises:
    ValueError: If more than one factory name is a prefix of the attribute
      name.
  """
  # `wanted_name` selects the saveable factory, while `input_name` is the
  # value handed to that factory to instantiate the SaveableObject.
  wanted_name = serialized_tensor.name
  input_name = serialized_tensor.checkpoint_key

  # Case 1: the name already exactly matches a key in saveable_factories.
  factory = saveable_factories.get(wanted_name)

  # Case 2: (Forward compat) The serialized name is composed of
  # "factory_name" + "SUFFIX". Get the matching factory name.
  if factory is None:
    for candidate_name, candidate in saveable_factories.items():
      if not wanted_name.startswith(candidate_name):
        continue
      if factory is not None:
        # This condition is met in the extreme edge case where the object
        # returns two saveable factories with similar names. This is very
        # unlikely because there are zero objects inside TensorFlow that use
        # more than one saveable factory.
        raise ValueError("Forward compatibility load error: Unable to load "
                         "checkpoint saved in future version of TensorFlow. "
                         "Please update your version of TensorFlow to the "
                         "version in which the checkpoint was saved.")
      factory = candidate
      input_name = _extract_saveable_name(
          serialized_tensor.checkpoint_key) + candidate_name
      created_compat_names.add(candidate_name)

  if not callable(factory):
    return factory
  return factory(name=input_name)