# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic for restoring checkpointed values for Trackables."""

import collections

from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import registration
from tensorflow.python.trackable import constants
from tensorflow.python.trackable import python_state
from tensorflow.python.trackable import trackable_utils


class CheckpointPosition(object):
  """Indicates a position within a `_CheckpointRestoreCoordinator`."""

  __slots__ = ["_checkpoint", "_proto_id", "skip_restore"]

  def __init__(self, checkpoint, proto_id):
    """Specify an object within a checkpoint.

    Args:
      checkpoint: A _CheckpointRestoreCoordinator object.
      proto_id: The index of this object in TrackableObjectGraph.nodes.
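
    For example, a sketch (assuming `coordinator` is an existing
    `_CheckpointRestoreCoordinator` and node 0 is the root object):

      position = CheckpointPosition(checkpoint=coordinator, proto_id=0)
      position.restore(root_trackable)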
42    """
43    self._checkpoint = checkpoint
44    self._proto_id = proto_id
45    # This may be set to True if the registered saver cannot be used with this
46    # object.
47    self.skip_restore = False
48
49  def restore(self, trackable):
50    """Restore this value into `trackable`."""
51    with ops.init_scope():
52      if self.bind_object(trackable):
53        # This object's correspondence with a checkpointed object is new, so
54        # process deferred restorations for it and its dependencies.
55        restore_ops = self._restore_descendants()
56        if restore_ops:
57          self._checkpoint.new_restore_ops(restore_ops)
58
59  def bind_object(self, trackable):
60    """Set a checkpoint<->object correspondence.
61
62    Args:
63      trackable: The object to record a correspondence for.
64
65    Returns:
66      True if this is a new assignment, False if this object has already been
67      mapped to a checkpointed `Object` proto.
68    Raises:
69      AssertionError: If another object is already bound to the `Object` proto.
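
    For example (illustrative):

      position.bind_object(my_layer)  # True: first binding for this proto.
      position.bind_object(my_layer)  # False: already mapped to `my_layer`.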
70    """
71    checkpoint = self.checkpoint
72    checkpoint.all_python_objects.add(trackable)
73    current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
74    checkpoint.matched_proto_ids.add(self._proto_id)
75    if current_assignment is None:
76      checkpoint.object_by_proto_id[self._proto_id] = trackable
77      return True  # New assignment
78    else:
79      # The object was already mapped for this checkpoint load, which means
80      # we don't need to do anything besides check that the mapping is
81      # consistent (if the dependency DAG is not a tree then there are
82      # multiple paths to the same object).
83      if current_assignment is not trackable:
84        logging.warning(
85            "Inconsistent references when loading the checkpoint into this "
86            "object graph. For example, in the saved checkpoint object, "
87            "`model.layer.weight` and `model.layer_copy.weight` reference the "
88            "same variable, while in the current object these are two different"
89            " variables. The referenced variables are:"
90            f"({current_assignment} and {trackable}).")
91      return False  # Not a new assignment
92
93  def is_simple_variable(self):
94    """Determine whether this value is restorable with a Tensor initializer."""
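    # For example, a plain `tf.Variable` serializes to a single
    # "VARIABLE_VALUE" attribute with no children, so its checkpointed value
    # can be fed straight to the variable's initializer.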
    attributes = self.object_proto.attributes
    return (len(attributes) == 1 and
            attributes[0].name == constants.VARIABLE_VALUE_KEY and
            not self.object_proto.children)

  def value_tensors(self, shape_and_slices=None):
    """Create value `Tensor`s for this object's attributes.

    Does not require that the Python object has been created. Used for
    restore-on-create when executing eagerly.

    Args:
      shape_and_slices: A dict mapping from object attribute names to a shape
        and slice string that will be passed to a RestoreV2 op. If the dict is
        None or if an object attribute is not in the dict, the full tensor will
        be restored.

    Returns:
      A dictionary mapping from object attribute names to `Tensor`s.
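
    For example, a sketch (assuming `position` is a bound simple-variable
    position) that restores only the first five rows of a saved (10, 4)
    variable; the slice string follows the RestoreV2 format
    "full_shape dim0_start,dim0_length:dim1_start,dim1_length":

      tensors = position.value_tensors(
          shape_and_slices={"VARIABLE_VALUE": "10 4 0,5:0,4"})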
    """
    value_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
      with ops.init_scope():
        with ops.device(io_device):
          # Run the restore itself on the io_device (CPU, or as specified).
          if (shape_and_slices is not None and
              serialized_tensor.name in shape_and_slices):
            shape_and_slice = shape_and_slices[serialized_tensor.name]
          else:
            shape_and_slice = ""
          value, = io_ops.restore_v2(
              prefix=self._checkpoint.save_path_tensor,
              tensor_names=[checkpoint_key],
              shape_and_slices=[shape_and_slice],
              dtypes=[base_type],
              name="%s_checkpoint_read" % (serialized_tensor.name,))
        # Copy the value to the current device if necessary.
        value_tensors[serialized_tensor.name] = array_ops.identity(value)
    return value_tensors

  def gather_ops_or_named_saveables(self):
    """Looks up or creates SaveableObjects which don't have cached ops.

    Returns:
      A tuple of (
          existing_restore_ops: list,
          named_saveables: dict,
          python_positions: list,
          registered_savers: dict)
    """
    # pylint:disable=g-import-not-at-top
    # There are circular dependencies between Trackable and SaveableObject,
    # so we must import it here.
    # TODO(b/224069573): Remove this code from Trackable.
    from tensorflow.python.training.saving import saveable_object_util
    # pylint:enable=g-import-not-at-top

    recorded_registered_saver = self.get_registered_saver_name()
    if not (self.object_proto.attributes or recorded_registered_saver):
      return [], {}, [], {}

    existing_restore_ops = []
    named_saveables = {}
    python_positions = []
    registered_savers = collections.defaultdict(dict)

    saveable_factories = saveable_object_util.saveable_objects_from_trackable(
        self.trackable)
    saver_name = registration.get_registered_saver_name(self.trackable)

    if recorded_registered_saver:
      if not self.skip_restore:
        name = self.object_proto.registered_saver.object_name
        registered_savers[recorded_registered_saver][name] = self.trackable
      # Else: skip restoring this Trackable. This skip only happens when the
      # registered saver has strict predicate restore disabled; otherwise, an
      # error would have been raised in `self.get_registered_saver_name()`.
    elif saver_name:
      # In this case, the checkpoint has a recorded serialized tensor but no
      # registered saver, while the Trackable loading the checkpoint has
      # migrated to the registered checkpoint functionality (TPUEmbedding is an
      # example of this).

      # Set the Trackable's object name to the first checkpoint key that is
      # stored in checkpoint. If there is a use case that requires the other
      # keys, then we can take another look at this.
      registered_savers[saver_name] = {
          self.object_proto.attributes[0].checkpoint_key: self.trackable
      }
    elif isinstance(self.trackable, python_state.PythonState):
      python_positions.append(self)
    elif saveable_factories.keys() == {
        trackable_utils.SERIALIZE_TO_TENSORS_NAME
    }:
      existing_restore_ops, named_saveables = (
          self._create_serialize_to_tensor_saveable(saveable_factories))
    elif saveable_factories:
      existing_restore_ops, named_saveables = (
          self._create_saveables_by_attribute_name(saveable_factories))
    else:
      # If no registered savers were found, then it means that one or more
      # serialized tensors were never used.
      for serialized_tensor in self.object_proto.attributes:
        self._checkpoint.unused_attributes.setdefault(
            self._proto_id, []).append(serialized_tensor.name)
    return (existing_restore_ops, named_saveables, python_positions,
            registered_savers)

  def _create_serialize_to_tensor_saveable(self, saveable_factories):
    """Creates a saveable using the _serialize_to_tensor method."""
    # Extract the saveable name from the checkpoint key. This will be used as
    # the cache key or the name to pass to the saveable factory.
    suffix = saveable_compat.get_saveable_name(self.trackable) or ""
    saveable_name = _extract_saveable_name(
        self.object_proto.attributes[0].checkpoint_key) + suffix

    # Try to find the cached saveable (only in graph mode).
    if not context.executing_eagerly():
      existing_op = self._checkpoint.restore_ops_by_name.get(
          saveable_name, None)
      if existing_op is not None:
        return existing_op, {}

      saveables_cache = self._checkpoint.saveables_cache.setdefault(
          self.trackable, {})
      if saveable_name in saveables_cache:
        return [], {saveable_name: saveables_cache[saveable_name]}

    saveable = saveable_factories[trackable_utils.SERIALIZE_TO_TENSORS_NAME](
        name=saveable_name)
    if not context.executing_eagerly():
      saveables_cache[saveable_name] = saveable
    return [], {saveable_name: saveable}

  def _create_saveables_by_attribute_name(self, saveable_factories):
    """Creates or caches SaveableObjects by matching the attribute names.

    The attribute name keys in `saveable_factories` are used to find the
    corresponding attribute in the object proto. Attributes contain checkpoint
    keys, which are passed to the factory function to generate the
    SaveableObject.

    Args:
      saveable_factories: a dict mapping attribute name to a callable factory
        function that produces a SaveableObject.

    Returns:
      A tuple of (
          existing_restore_ops: list,
          named_saveables: dict)
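
    For example (illustrative), a variable-like object might provide:

      saveable_factories = {"VARIABLE_VALUE": variable_saveable_factory}

    The "VARIABLE_VALUE" attribute in the object proto then supplies the
    checkpoint key that is passed to the factory to build the SaveableObject.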
    """
    # Name saveables based on the name this object had when it was checkpointed.
    named_saveables = {}
    existing_restore_ops = []

    # Forward compatibility code: when loading a future checkpoint, there may
    # be multiple SerializedTensors mapped to a single saveable.
    created_compat_names = set()

    for serialized_tensor in self.object_proto.attributes:
      if context.executing_eagerly():
        existing_op = None
      else:
        existing_op = self._checkpoint.restore_ops_by_name.get(
            serialized_tensor.checkpoint_key, None)
      if existing_op is not None:
        existing_restore_ops.append(existing_op)
        continue

      if any(serialized_tensor.name.startswith(name)
             for name in created_compat_names):
        continue  # Saveable has already been created for this tensor.

      # Only if we don't have cached ops for this SaveableObject, we'll see if
      # the SaveableObject itself has been cached. If not, we'll make it, and
      # either way we'll extract new ops from it (or if it has Python state to
      # restore, we'll run that).
      saveables_cache = self._checkpoint.saveables_cache
      if saveables_cache is None:
        # No SaveableObject caching when executing eagerly.
        saveable = None
      else:
        # If we've already created and cached a SaveableObject for this
        # attribute, we can re-use it to avoid re-creating some ops when graph
        # building.
        saveable_list = saveables_cache.get(self.trackable,
                                            {}).get(serialized_tensor.name,
                                                    (None,))
        if len(saveable_list) == 1:
          # Almost every attribute will have exactly one SaveableObject.
          saveable, = saveable_list
        else:
          # Don't use cached SaveableObjects for partitioned variables, which is
          # the only case where we'd have a list of SaveableObjects. Op caching
          # will catch them.
          saveable = None
      if saveable is not None:
        # The name of this attribute has changed, so we need to re-generate
        # the SaveableObject.
        if serialized_tensor.checkpoint_key not in saveable.name:
          saveable = None
          del saveables_cache[self.trackable]
      if saveable is None:
        # If there was no cached SaveableObject, create one.
        # Use the name to check if the Python object has the same attribute.
        saveable = _get_saveable_from_factory(saveable_factories,
                                              serialized_tensor,
                                              created_compat_names)
        if saveable is None:
          # Purposefully does not throw an exception if attributes have been
          # added or deleted. Stores unused attributes so an exception can be
          # raised if the user decides to check that everything in the
          # checkpoint was loaded.
          self._checkpoint.unused_attributes.setdefault(
              self._proto_id, []).append(serialized_tensor.name)
          continue
        if saveables_cache is not None:
          saveables_cache.setdefault(self.trackable,
                                     {})[serialized_tensor.name] = [saveable]
      named_saveables[serialized_tensor.checkpoint_key] = saveable

    return existing_restore_ops, named_saveables

  def restore_ops(self):
    """Create or fetch restore ops for this object's attributes.

    Requires that the `Trackable` Python object has been bound to an object
    ID in the checkpoint.

    Returns:
      A list of operations when graph building, or an empty list when executing
      eagerly.
    """
    if self._has_registered_saver():
      raise ValueError("Unable to run individual checkpoint restore for objects"
                       " with registered savers.")
    (restore_ops, tensor_saveables, python_positions,
     _) = self.gather_ops_or_named_saveables()
    restore_ops.extend(
        self._checkpoint.restore_saveables(tensor_saveables, python_positions))
    return restore_ops

  @property
  def checkpoint(self):
    return self._checkpoint

  @property
  def trackable(self):
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def proto_id(self):
    return self._proto_id

  @property
  def restore_uid(self):
    return self._checkpoint.restore_uid

  def __repr__(self):
    return repr(self.object_proto)

  def value_shape(self):
    """The shape of the VARIABLE_VALUE tensor.

    Returns:
      A `TensorShape` object if found, otherwise None.
    """
    for serialized_tensor in self.object_proto.attributes:
      if serialized_tensor.name == constants.VARIABLE_VALUE_KEY:
        return self._checkpoint.shape_map[serialized_tensor.checkpoint_key]
    return None

  def _has_registered_saver(self):
    return bool(self.object_proto.registered_saver.name)

  def get_registered_saver_name(self):
    """Returns the registered saver name defined in the Checkpoint."""
    if self._has_registered_saver():
      saver_name = self.object_proto.registered_saver.name
      try:
        registration.validate_restore_function(self.trackable, saver_name)
      except ValueError as e:
        if registration.get_strict_predicate_restore(saver_name):
          raise e
        self.skip_restore = True
      return saver_name
    return None

  def create_slot_variable_position(self, optimizer_object, variable,
                                    slot_variable_id, slot_name):
    """Generates CheckpointPosition for a slot variable.

    Args:
      optimizer_object: Optimizer that owns the slot variable.
      variable: Variable associated with the slot variable.
      slot_variable_id: ID of the slot variable.
      slot_name: Name of the slot variable.

    Returns:
      If there is a slot variable in the `optimizer_object` that has not been
      bound to the checkpoint, this function returns a tuple of (
        new `CheckpointPosition` for the slot variable,
        the slot variable itself).
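      Otherwise, returns (None, None).

      For example (illustrative):

        position, slot_var = checkpoint_position.create_slot_variable_position(
            optimizer, var, slot_variable_id=3, slot_name="momentum")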
    """
    slot_variable_position = CheckpointPosition(
        checkpoint=self.checkpoint, proto_id=slot_variable_id)
    # pylint: disable=protected-access
    slot_variable = optimizer_object._create_or_restore_slot_variable(
        slot_variable_position=slot_variable_position,
        variable=variable,
        slot_name=slot_name)
    # pylint: enable=protected-access
    if (slot_variable is not None and
        slot_variable_position.bind_object(slot_variable)):
      return slot_variable_position, slot_variable
    else:
      return None, None

  def create_child_position(self, node_id):
    return CheckpointPosition(checkpoint=self.checkpoint, proto_id=node_id)

  def _restore_descendants(self):
    """Restore the bound Trackable and dependencies (may be deferred)."""
    # Attempt a breadth-first traversal, since presumably the user has more
    # control over shorter paths. If we don't have all of the dependencies at
    # this point, the end result is not breadth-first (since other deferred
    # traversals will happen later).

    # You may be wondering why elements in the `visit_queue` are tuples that
    # contain both CheckpointPositions and their Trackables. The reason is that
    # Optimizers will not keep a strong reference to slot vars for
    # ShardedVariables. The slot variable must be kept in memory until the
    # restore saveables have been created.
    visit_queue = collections.deque([(self, self.trackable)])
    restore_ops = []
    tensor_saveables = {}
    python_positions = []
    registered_savers = collections.defaultdict(dict)
    while visit_queue:
      current_position, _ = visit_queue.popleft()

      # Restore using the ops defined in a Saveable or registered function.
      (new_restore_ops, new_tensor_saveables, new_python_positions,
       new_registered_savers) = current_position._single_restore()  # pylint: disable=protected-access
      restore_ops.extend(new_restore_ops)
      tensor_saveables.update(new_tensor_saveables)
      python_positions.extend(new_python_positions)
      for saver_name, trackable_map in new_registered_savers.items():
        registered_savers[saver_name].update(trackable_map)

      # Pass the restoration to the dependencies.
      _queue_children_for_restoration(current_position, visit_queue)
      _queue_slot_variables(current_position, visit_queue)

    restore_ops.extend(
        current_position.checkpoint.restore_saveables(tensor_saveables,
                                                      python_positions,
                                                      registered_savers))
    return restore_ops

  def _single_restore(self):
    """Restores the trackable."""
    trackable = self.trackable
    trackable._maybe_initialize_trackable()  # pylint: disable=protected-access
    checkpoint = self.checkpoint
    # If the UID of this restore is lower than our current update UID, we don't
    # need to actually restore the object.
    if checkpoint.restore_uid > trackable._update_uid:  # pylint: disable=protected-access
      restore_ops, tensor_saveables, python_positions, registered_savers = (
          self.gather_ops_or_named_saveables())
      trackable._update_uid = checkpoint.restore_uid  # pylint: disable=protected-access
    else:
      restore_ops = ()
      tensor_saveables = {}
      python_positions = ()
      registered_savers = {}
    return restore_ops, tensor_saveables, python_positions, registered_savers


def _queue_children_for_restoration(checkpoint_position, visit_queue):
  """Queues the restoration of trackable's children or defers them."""
  # pylint: disable=protected-access
  trackable = checkpoint_position.trackable
  for child in checkpoint_position.object_proto.children:
    child_position = checkpoint_position.create_child_position(child.node_id)
    local_object = trackable._lookup_dependency(child.local_name)
    child_proto = child_position.object_proto
    if local_object is None:
      # We don't yet have a dependency registered with this name. Save it
      # in case we do.
      if child_proto.HasField("has_checkpoint_values"):
        has_value = child_proto.has_checkpoint_values.value
      else:
        # If the field is not set, do a simple check to see if the dependency
        # has children and/or checkpointed values.
        has_value = bool(
            child_proto.children or child_proto.attributes or
            child_proto.slot_variables or
            child_proto.HasField("registered_saver"))
      if has_value:
        trackable._deferred_dependencies.setdefault(child.local_name,
                                                    []).append(child_position)
    else:
      if child_position.bind_object(trackable=local_object):
        # This object's correspondence is new, so dependencies need to be
        # visited. Delay doing it so that we get a breadth-first dependency
        # resolution order (shallowest paths first). The caller is responsible
        # for emptying visit_queue.
        visit_queue.append((child_position, local_object))


_DeferredSlotVariableRestoration = collections.namedtuple(
    "_DeferredSlotVariableRestoration", [
        "original_variable",
        "slot_variable_id",
        "slot_name",
    ])


def _queue_slot_variables(checkpoint_position, visit_queue):
  """Queues slot variables for restoration."""
  trackable = checkpoint_position.trackable
  checkpoint = checkpoint_position.checkpoint
  for deferred_slot_restoration in (checkpoint.deferred_slot_restorations.pop(
      checkpoint_position.proto_id, ())):
    slot_variable_position, slot_variable = (
        checkpoint_position.create_slot_variable_position(
            trackable, deferred_slot_restoration.original_variable,
            deferred_slot_restoration.slot_variable_id,
            deferred_slot_restoration.slot_name))
    if slot_variable_position is not None:
      visit_queue.append((slot_variable_position, slot_variable))
  for slot_restoration in checkpoint.slot_restorations.pop(
      checkpoint_position.proto_id, ()):
    optimizer_object = checkpoint.object_by_proto_id.get(
        slot_restoration.optimizer_id, None)
    if optimizer_object is None:
      # The optimizer has not yet been created or tracked. Record in the
      # checkpoint that the slot variables need to be restored when it is.
      checkpoint.deferred_slot_restorations.setdefault(
          slot_restoration.optimizer_id, []).append(
              _DeferredSlotVariableRestoration(
                  original_variable=trackable,
                  slot_variable_id=slot_restoration.slot_variable_id,
                  slot_name=slot_restoration.slot_name))

    # `optimizer_object` can be a `Checkpoint` when the user only needs the
    # attributes the optimizer holds, such as `iterations`. In those cases,
    # it would not have the optimizer's `_create_or_restore_slot_variable`
    # method.
    elif hasattr(optimizer_object, "_create_or_restore_slot_variable"):
      slot_variable_position, slot_variable = (
          checkpoint_position.create_slot_variable_position(
              optimizer_object, trackable, slot_restoration.slot_variable_id,
              slot_restoration.slot_name))
      if slot_variable_position is not None:
        visit_queue.append((slot_variable_position, slot_variable))


def _extract_saveable_name(checkpoint_key):
  # Truncate the checkpoint key so that it ends just after the
  # "{...}.ATTRIBUTES/" substring.
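  # For example, "model/dense/kernel/.ATTRIBUTES/VARIABLE_VALUE" yields
  # "model/dense/kernel/.ATTRIBUTES/".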
  search_key = trackable_utils.OBJECT_ATTRIBUTES_NAME + "/"
  return checkpoint_key[:checkpoint_key.index(search_key) + len(search_key)]


def _get_saveable_from_factory(saveable_factories, serialized_tensor,
                               created_compat_names):
  """Returns the saveable generated from the factory method."""
  matched_factory = None

  # The `expected_factory_name` is used to find the right saveable factory,
  # while the `factory_input_name` is the value that is passed to the factory
  # method to instantiate the SaveableObject.
  expected_factory_name = serialized_tensor.name
  factory_input_name = serialized_tensor.checkpoint_key

  # Case 1: the name already exactly matches a key in saveable_factories.
  if expected_factory_name in saveable_factories:
    matched_factory = saveable_factories[expected_factory_name]

  # Case 2: (Forward compat) The serialized name is composed of
  # "factory_name" + "SUFFIX". Get the matching factory name.
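  # For example (illustrative), a future checkpoint might record the name
  # "kernel_compat" while the current object only registers a "kernel"
  # factory; the prefix match below maps the former onto the latter.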
  if matched_factory is None:
    for factory_name, factory in saveable_factories.items():
      if expected_factory_name.startswith(factory_name):
        if matched_factory is not None:
          # This condition is met in the extreme edge case where the object
          # returns two saveable factories with similar names. This is very
          # unlikely because there are zero objects inside TensorFlow that use
          # more than one saveable factory.
          raise ValueError("Forward compatibility load error: Unable to load "
                           "checkpoint saved in future version of TensorFlow. "
                           "Please update your version of TensorFlow to the "
                           "version in which the checkpoint was saved.")

        matched_factory = factory
        factory_input_name = _extract_saveable_name(
            serialized_tensor.checkpoint_key) + factory_name
        created_compat_names.add(factory_name)

  if callable(matched_factory):
    return matched_factory(name=factory_input_name)
  return matched_factory
607