1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Logic for restoring checkpointed values for Trackables.""" 16 17import collections 18 19from tensorflow.python.checkpoint import saveable_compat 20from tensorflow.python.eager import context 21from tensorflow.python.framework import ops 22from tensorflow.python.ops import array_ops 23from tensorflow.python.ops import gen_io_ops as io_ops 24from tensorflow.python.platform import tf_logging as logging 25from tensorflow.python.saved_model import registration 26from tensorflow.python.trackable import constants 27from tensorflow.python.trackable import python_state 28from tensorflow.python.trackable import trackable_utils 29 30 31class CheckpointPosition(object): 32 """Indicates a position within a `_CheckpointRestoreCoordinator`.""" 33 34 __slots__ = ["_checkpoint", "_proto_id", "skip_restore"] 35 36 def __init__(self, checkpoint, proto_id): 37 """Specify an object within a checkpoint. 38 39 Args: 40 checkpoint: A _CheckpointRestoreCoordinator object. 41 proto_id: The index of this object in TrackableObjectGraph.nodes. 42 """ 43 self._checkpoint = checkpoint 44 self._proto_id = proto_id 45 # This may be set to True if the registered saver cannot be used with this 46 # object. 47 self.skip_restore = False 48 49 def restore(self, trackable): 50 """Restore this value into `trackable`.""" 51 with ops.init_scope(): 52 if self.bind_object(trackable): 53 # This object's correspondence with a checkpointed object is new, so 54 # process deferred restorations for it and its dependencies. 55 restore_ops = self._restore_descendants() 56 if restore_ops: 57 self._checkpoint.new_restore_ops(restore_ops) 58 59 def bind_object(self, trackable): 60 """Set a checkpoint<->object correspondence. 61 62 Args: 63 trackable: The object to record a correspondence for. 64 65 Returns: 66 True if this is a new assignment, False if this object has already been 67 mapped to a checkpointed `Object` proto. 68 Raises: 69 AssertionError: If another object is already bound to the `Object` proto. 70 """ 71 checkpoint = self.checkpoint 72 checkpoint.all_python_objects.add(trackable) 73 current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None) 74 checkpoint.matched_proto_ids.add(self._proto_id) 75 if current_assignment is None: 76 checkpoint.object_by_proto_id[self._proto_id] = trackable 77 return True # New assignment 78 else: 79 # The object was already mapped for this checkpoint load, which means 80 # we don't need to do anything besides check that the mapping is 81 # consistent (if the dependency DAG is not a tree then there are 82 # multiple paths to the same object). 83 if current_assignment is not trackable: 84 logging.warning( 85 "Inconsistent references when loading the checkpoint into this " 86 "object graph. For example, in the saved checkpoint object, " 87 "`model.layer.weight` and `model.layer_copy.weight` reference the " 88 "same variable, while in the current object these are two different" 89 " variables. The referenced variables are:" 90 f"({current_assignment} and {trackable}).") 91 return False # Not a new assignment 92 93 def is_simple_variable(self): 94 """Determine whether this value is restorable with a Tensor initializer.""" 95 attributes = self.object_proto.attributes 96 return (len(attributes) == 1 and 97 attributes[0].name == constants.VARIABLE_VALUE_KEY and 98 not self.object_proto.children) 99 100 def value_tensors(self, shape_and_slices=None): 101 """Create value `Tensor`s for this object's attributes. 102 103 Does not require that the Python object has been created. Used for 104 restore-on-create when executing eagerly. 105 106 Args: 107 shape_and_slices: A dict mapping from object attribute names to a shape 108 and slice string that will be passed to a RestoreV2 op. If the dict is 109 None or if an object attribute is not in the dict, the full tensor will 110 be restored. 111 112 Returns: 113 A dictionary mapping from object attribute names to `Tensor`s. 114 """ 115 value_tensors = {} 116 for serialized_tensor in self.object_proto.attributes: 117 checkpoint_key = serialized_tensor.checkpoint_key 118 dtype = self._checkpoint.dtype_map[checkpoint_key] 119 base_type = dtype.base_dtype 120 io_device = self._checkpoint.options.experimental_io_device or "cpu:0" 121 with ops.init_scope(): 122 with ops.device(io_device): 123 # Run the restore itself on the io_device(CPU or specified). 124 if (shape_and_slices is not None and 125 serialized_tensor.name in shape_and_slices): 126 shape_and_slice = shape_and_slices[serialized_tensor.name] 127 else: 128 shape_and_slice = "" 129 value, = io_ops.restore_v2( 130 prefix=self._checkpoint.save_path_tensor, 131 tensor_names=[checkpoint_key], 132 shape_and_slices=[shape_and_slice], 133 dtypes=[base_type], 134 name="%s_checkpoint_read" % (serialized_tensor.name,)) 135 # Copy the value to the current device if necessary. 136 value_tensors[serialized_tensor.name] = array_ops.identity(value) 137 return value_tensors 138 139 def gather_ops_or_named_saveables(self): 140 """Looks up or creates SaveableObjects which don't have cached ops. 141 142 Returns: 143 A tuple of ( 144 existing_restore_ops: list, 145 named_saveables: dict, 146 python_positions: list, 147 registered_savers: dict) 148 """ 149 # pylint:disable=g-import-not-at-top 150 # There are circular dependencies between Trackable and SaveableObject, 151 # so we must import it here. 152 # TODO(b/224069573): Remove this code from Trackable. 153 from tensorflow.python.training.saving import saveable_object_util 154 # pylint:enable=g-import-not-at-top 155 156 recorded_registered_saver = self.get_registered_saver_name() 157 if not (self.object_proto.attributes or recorded_registered_saver): 158 return [], {}, [], {} 159 160 existing_restore_ops = [] 161 named_saveables = {} 162 python_positions = [] 163 registered_savers = collections.defaultdict(dict) 164 165 saveable_factories = saveable_object_util.saveable_objects_from_trackable( 166 self.trackable) 167 saver_name = registration.get_registered_saver_name(self.trackable) 168 169 if recorded_registered_saver: 170 if not self.skip_restore: 171 name = self.object_proto.registered_saver.object_name 172 registered_savers[recorded_registered_saver][name] = self.trackable 173 # Else: Skip restoration of this Trackable. This skip only happens if the 174 # registered saver has enabled `option_restore`. Otherwise, an error would 175 # have been raised at `self.get_registered_saver_name()`. 176 elif saver_name: 177 # In this case, the checkpoint has a recorded serialized tensor but no 178 # registered saver, while the Trackable loading the checkpoint has 179 # migrated to the registered checkpoint functionality (TPUEmbedding is an 180 # example of this). 181 182 # Set the Trackable's object name to the first checkpoint key that is 183 # stored in checkpoint. If there is a use case that requires the other 184 # keys, then we can take another look at this. 185 registered_savers[saver_name] = { 186 self.object_proto.attributes[0].checkpoint_key: self.trackable 187 } 188 elif isinstance(self.trackable, python_state.PythonState): 189 python_positions.append(self) 190 elif saveable_factories.keys() == { 191 trackable_utils.SERIALIZE_TO_TENSORS_NAME 192 }: 193 existing_restore_ops, named_saveables = ( 194 self._create_serialize_to_tensor_saveable(saveable_factories)) 195 elif saveable_factories: 196 existing_restore_ops, named_saveables = ( 197 self._create_saveables_by_attribute_name(saveable_factories)) 198 else: 199 # If no registered savers were found, then it means that one or more 200 # serialized tensors were never used. 201 for serialized_tensor in self.object_proto.attributes: 202 self._checkpoint.unused_attributes.setdefault( 203 self._proto_id, []).append(serialized_tensor.name) 204 return (existing_restore_ops, named_saveables, python_positions, 205 registered_savers) 206 207 def _create_serialize_to_tensor_saveable(self, saveable_factories): 208 """Creates a saveable using the _serialize_to_tensor method.""" 209 # Extract the saveable name from the checkpoint key. This will be used as 210 # the cache key or the name to pass to the saveable factory. 211 suffix = saveable_compat.get_saveable_name(self.trackable) or "" 212 saveable_name = _extract_saveable_name( 213 self.object_proto.attributes[0].checkpoint_key) + suffix 214 215 # Try to find the cached saveable (only in graph mode). 216 if not context.executing_eagerly(): 217 existing_op = self._checkpoint.restore_ops_by_name.get( 218 saveable_name, None) 219 if existing_op is not None: 220 return existing_op, {} 221 222 saveables_cache = self._checkpoint.saveables_cache.setdefault( 223 self.trackable, {}) 224 if saveable_name in saveables_cache: 225 return [], {saveable_name: saveables_cache[saveable_name]} 226 227 saveable = saveable_factories[trackable_utils.SERIALIZE_TO_TENSORS_NAME]( 228 name=saveable_name) 229 if not context.executing_eagerly(): 230 saveables_cache[saveable_name] = saveable 231 return [], {saveable_name: saveable} 232 233 def _create_saveables_by_attribute_name(self, saveable_factories): 234 """Creates or caches SaveableObjects by matching the attribute names. 235 236 The attribute name keys in the `saveable_factories` is used to find the 237 corresponding attribute in the object proto. Attributes contain checkpoint 238 keys which are passed to the factory function to generate the 239 SaveableObject. 240 241 Args: 242 saveable_factories: a dict mapping attribute name to a callable factory 243 function that produces a SaveableObject. 244 245 Returns: 246 A tuple of ( 247 existing_restore_ops: list, 248 named_saveables: dict) 249 """ 250 # Name saveables based on the name this object had when it was checkpointed. 251 named_saveables = {} 252 existing_restore_ops = [] 253 254 # Forward compatibility code: when loading a future checkpoint, there may 255 # be multiple SerializedTensors mapped to a single saveable. 256 created_compat_names = set() 257 258 for serialized_tensor in self.object_proto.attributes: 259 if context.executing_eagerly(): 260 existing_op = None 261 else: 262 existing_op = self._checkpoint.restore_ops_by_name.get( 263 serialized_tensor.checkpoint_key, None) 264 if existing_op is not None: 265 existing_restore_ops.append(existing_op) 266 continue 267 268 if any(serialized_tensor.name.startswith(name) 269 for name in created_compat_names): 270 continue # Saveable has already been created for this tensor. 271 272 # Only if we don't have cached ops for this SaveableObject, we'll see if 273 # the SaveableObject itself has been cached. If not, we'll make it, and 274 # either way we'll extract new ops from it (or if it has Python state to 275 # restore, we'll run that). 276 saveables_cache = self._checkpoint.saveables_cache 277 if saveables_cache is None: 278 # No SaveableObject caching when executing eagerly. 279 saveable = None 280 else: 281 # If we've already created and cached a SaveableObject for this 282 # attribute, we can re-use it to avoid re-creating some ops when graph 283 # building. 284 saveable_list = saveables_cache.get(self.trackable, 285 {}).get(serialized_tensor.name, 286 (None,)) 287 if len(saveable_list) == 1: 288 # Almost every attribute will have exactly one SaveableObject. 289 saveable, = saveable_list 290 else: 291 # Don't use cached SaveableObjects for partitioned variables, which is 292 # the only case where we'd have a list of SaveableObjects. Op caching 293 # will catch them. 294 saveable = None 295 if saveable is not None: 296 # The name of this attribute has changed, so we need to re-generate 297 # the SaveableObject. 298 if serialized_tensor.checkpoint_key not in saveable.name: 299 saveable = None 300 del saveables_cache[self.trackable] 301 if saveable is None: 302 # If there was no cached SaveableObject, create one. 303 # Use the name to check if the Python object has the same attribute. 304 saveable = _get_saveable_from_factory(saveable_factories, 305 serialized_tensor, 306 created_compat_names) 307 if saveable is None: 308 # Purposefully does not throw an exception if attributes have been 309 # added or deleted. Stores unused attributes so an exception can be 310 # raised if the user decides to check that everything in the 311 # checkpoint was loaded. 312 self._checkpoint.unused_attributes.setdefault( 313 self._proto_id, []).append(serialized_tensor.name) 314 continue 315 if saveables_cache is not None: 316 saveables_cache.setdefault(self.trackable, 317 {})[serialized_tensor.name] = [saveable] 318 named_saveables[serialized_tensor.checkpoint_key] = saveable 319 320 return existing_restore_ops, named_saveables 321 322 def restore_ops(self): 323 """Create or fetch restore ops for this object's attributes. 324 325 Requires that the `Trackable` Python object has been bound to an object 326 ID in the checkpoint. 327 328 Returns: 329 A list of operations when graph building, or an empty list when executing 330 eagerly. 331 """ 332 if self._has_registered_saver(): 333 raise ValueError("Unable to run individual checkpoint restore for objects" 334 " with registered savers.") 335 (restore_ops, tensor_saveables, python_positions, 336 _) = self.gather_ops_or_named_saveables() 337 restore_ops.extend( 338 self._checkpoint.restore_saveables(tensor_saveables, python_positions)) 339 return restore_ops 340 341 @property 342 def checkpoint(self): 343 return self._checkpoint 344 345 @property 346 def trackable(self): 347 return self._checkpoint.object_by_proto_id[self._proto_id] 348 349 @property 350 def object_proto(self): 351 return self._checkpoint.object_graph_proto.nodes[self._proto_id] 352 353 @property 354 def proto_id(self): 355 return self._proto_id 356 357 @property 358 def restore_uid(self): 359 return self._checkpoint.restore_uid 360 361 def __repr__(self): 362 return repr(self.object_proto) 363 364 def value_shape(self): 365 """The shape of the VARIABLE_VALUE tensor. 366 367 Returns: 368 If found a TensorShape object, otherwise None. 369 """ 370 for serialized_tensor in self.object_proto.attributes: 371 if serialized_tensor.name == constants.VARIABLE_VALUE_KEY: 372 return self._checkpoint.shape_map[serialized_tensor.checkpoint_key] 373 return None 374 375 def _has_registered_saver(self): 376 return bool(self.object_proto.registered_saver.name) 377 378 def get_registered_saver_name(self): 379 """Returns the registered saver name defined in the Checkpoint.""" 380 if self._has_registered_saver(): 381 saver_name = self.object_proto.registered_saver.name 382 try: 383 registration.validate_restore_function(self.trackable, saver_name) 384 except ValueError as e: 385 if registration.get_strict_predicate_restore(saver_name): 386 raise e 387 self.skip_restore = True 388 return saver_name 389 return None 390 391 def create_slot_variable_position(self, optimizer_object, variable, 392 slot_variable_id, slot_name): 393 """Generates CheckpointPosition for a slot variable. 394 395 Args: 396 optimizer_object: Optimizer that owns the slot variable. 397 variable: Variable associated with the slot variable. 398 slot_variable_id: ID of the slot variable. 399 slot_name: Name of the slot variable. 400 401 Returns: 402 If there is a slot variable in the `optimizer_object` that has not been 403 bound to the checkpoint, this function returns a tuple of ( 404 new `CheckpointPosition` for the slot variable, 405 the slot variable itself). 406 """ 407 slot_variable_position = CheckpointPosition( 408 checkpoint=self.checkpoint, proto_id=slot_variable_id) 409 # pylint: disable=protected-access 410 slot_variable = optimizer_object._create_or_restore_slot_variable( 411 slot_variable_position=slot_variable_position, 412 variable=variable, 413 slot_name=slot_name) 414 # pylint: enable=protected-access 415 if (slot_variable is not None and 416 slot_variable_position.bind_object(slot_variable)): 417 return slot_variable_position, slot_variable 418 else: 419 return None, None 420 421 def create_child_position(self, node_id): 422 return CheckpointPosition(checkpoint=self.checkpoint, proto_id=node_id) 423 424 def _restore_descendants(self): 425 """Restore the bound Trackable and dependencies (may be deferred).""" 426 # Attempt a breadth-first traversal, since presumably the user has more 427 # control over shorter paths. If we don't have all of the dependencies at 428 # this point, the end result is not breadth-first (since other deferred 429 # traversals will happen later). 430 431 # You may be wondering why elements in the `visit_queue` are tuples that 432 # contains both CheckpointPositions and their Trackable. The reason is that 433 # Optimizers will not keep a strong reference to slot vars for 434 # ShardedVariables. The slot variable must be kept in memory until the 435 # restore saveables have been created. 436 visit_queue = collections.deque([(self, self.trackable)]) 437 restore_ops = [] 438 tensor_saveables = {} 439 python_positions = [] 440 registered_savers = collections.defaultdict(dict) 441 while visit_queue: 442 current_position, _ = visit_queue.popleft() 443 444 # Restore using the ops defined in a Saveable or registered function. 445 (new_restore_ops, new_tensor_saveables, new_python_positions, 446 new_registered_savers) = current_position._single_restore() # pylint: disable=protected-access 447 restore_ops.extend(new_restore_ops) 448 tensor_saveables.update(new_tensor_saveables) 449 python_positions.extend(new_python_positions) 450 for saver_name, trackable_map in new_registered_savers.items(): 451 registered_savers[saver_name].update(trackable_map) 452 453 # Pass the restoration to the dependencies. 454 _queue_children_for_restoration(current_position, visit_queue) 455 _queue_slot_variables(current_position, visit_queue) 456 457 restore_ops.extend( 458 current_position.checkpoint.restore_saveables(tensor_saveables, 459 python_positions, 460 registered_savers)) 461 return restore_ops 462 463 def _single_restore(self): 464 """Restores the trackable.""" 465 trackable = self.trackable 466 trackable._maybe_initialize_trackable() # pylint: disable=protected-access 467 checkpoint = self.checkpoint 468 # If the UID of this restore is lower than our current update UID, we don't 469 # need to actually restore the object. 470 if checkpoint.restore_uid > trackable._update_uid: # pylint: disable=protected-access 471 restore_ops, tensor_saveables, python_positions, registered_savers = ( 472 self.gather_ops_or_named_saveables()) 473 trackable._update_uid = checkpoint.restore_uid # pylint: disable=protected-access 474 else: 475 restore_ops = () 476 tensor_saveables = {} 477 python_positions = () 478 registered_savers = {} 479 return restore_ops, tensor_saveables, python_positions, registered_savers 480 481 482def _queue_children_for_restoration(checkpoint_position, visit_queue): 483 """Queues the restoration of trackable's children or defers them.""" 484 # pylint: disable=protected-access 485 trackable = checkpoint_position.trackable 486 for child in checkpoint_position.object_proto.children: 487 child_position = checkpoint_position.create_child_position(child.node_id) 488 local_object = trackable._lookup_dependency(child.local_name) 489 child_proto = child_position.object_proto 490 if local_object is None: 491 # We don't yet have a dependency registered with this name. Save it 492 # in case we do. 493 if child_proto.HasField("has_checkpoint_values"): 494 has_value = child_proto.has_checkpoint_values.value 495 else: 496 # If the field is not set, do a simple check to see if the dependency 497 # has children and/or checkpointed values. 498 has_value = bool( 499 child_proto.children or child_proto.attributes or 500 child_proto.slot_variables or 501 child_proto.HasField("registered_saver")) 502 if has_value: 503 trackable._deferred_dependencies.setdefault(child.local_name, 504 []).append(child_position) 505 else: 506 if child_position.bind_object(trackable=local_object): 507 # This object's correspondence is new, so dependencies need to be 508 # visited. Delay doing it so that we get a breadth-first dependency 509 # resolution order (shallowest paths first). The caller is responsible 510 # for emptying visit_queue. 511 visit_queue.append((child_position, local_object)) 512 513 514_DeferredSlotVariableRestoration = collections.namedtuple( 515 "_DeferredSlotVariableRestoration", [ 516 "original_variable", 517 "slot_variable_id", 518 "slot_name", 519 ]) 520 521 522def _queue_slot_variables(checkpoint_position, visit_queue): 523 """Queues slot variables for restoration.""" 524 trackable = checkpoint_position.trackable 525 checkpoint = checkpoint_position.checkpoint 526 for deferred_slot_restoration in (checkpoint.deferred_slot_restorations.pop( 527 checkpoint_position.proto_id, ())): 528 slot_variable_position, slot_variable = ( 529 checkpoint_position.create_slot_variable_position( 530 trackable, deferred_slot_restoration.original_variable, 531 deferred_slot_restoration.slot_variable_id, 532 deferred_slot_restoration.slot_name)) 533 if slot_variable_position is not None: 534 visit_queue.append((slot_variable_position, slot_variable)) 535 for slot_restoration in checkpoint.slot_restorations.pop( 536 checkpoint_position.proto_id, ()): 537 optimizer_object = checkpoint.object_by_proto_id.get( 538 slot_restoration.optimizer_id, None) 539 if optimizer_object is None: 540 # The optimizer has not yet been created or tracked. Record in the 541 # checkpoint that the slot variables need to be restored when it is. 542 checkpoint.deferred_slot_restorations.setdefault( 543 slot_restoration.optimizer_id, []).append( 544 _DeferredSlotVariableRestoration( 545 original_variable=trackable, 546 slot_variable_id=slot_restoration.slot_variable_id, 547 slot_name=slot_restoration.slot_name)) 548 549 # `optimizer_object` can be a `Checkpoint` when user only needs the 550 # attributes the optimizer holds, such as `iterations`. In those cases, 551 # it would not have the optimizer's `_create_or_restore_slot_variable` 552 # method. 553 elif hasattr(optimizer_object, "_create_or_restore_slot_variable"): 554 slot_variable_position, slot_variable = ( 555 checkpoint_position.create_slot_variable_position( 556 optimizer_object, trackable, slot_restoration.slot_variable_id, 557 slot_restoration.slot_name)) 558 if slot_variable_position is not None: 559 visit_queue.append((slot_variable_position, slot_variable)) 560 561 562def _extract_saveable_name(checkpoint_key): 563 # Substring the checkpoint key to the end of the "{...}.ATTRIBUTES/" 564 search_key = trackable_utils.OBJECT_ATTRIBUTES_NAME + "/" 565 return checkpoint_key[:checkpoint_key.index(search_key) + len(search_key)] 566 567 568def _get_saveable_from_factory(saveable_factories, serialized_tensor, 569 created_compat_names): 570 """Returns the saveable generated from the factory method.""" 571 matched_factory = None 572 573 # The `expected_factory_name` is used to find the right saveable factory, 574 # while the `factory_input_name` is the value that is passed to the factory 575 # method to instantiate the SaveableObject. 576 expected_factory_name = serialized_tensor.name 577 factory_input_name = serialized_tensor.checkpoint_key 578 579 # Case 1: the name already exactly matches a key in saveable_factories. 580 if expected_factory_name in saveable_factories: 581 matched_factory = saveable_factories[expected_factory_name] 582 583 # Case 2: (Forward compat) The serialized name is composed of 584 # "factory_name" + "SUFFIX". Get the matching factory name. 585 if matched_factory is None: 586 587 for factory_name, factory in saveable_factories.items(): 588 if expected_factory_name.startswith(factory_name): 589 if matched_factory is not None: 590 # This condition is met in the extreme edge case where the object 591 # returns two saveable factories with similar names. This is very 592 # unlikely because there zero objects inside TensorFlow that use 593 # more than one saveable factory. 594 raise ValueError("Forward compatibility load error: Unable to load " 595 "checkpoint saved in future version of TensorFlow. " 596 "Please update your version of TensorFlow to the " 597 "version in which the checkpoint was saved.") 598 599 matched_factory = factory 600 factory_input_name = _extract_saveable_name( 601 serialized_tensor.checkpoint_key) + factory_name 602 created_compat_names.add(factory_name) 603 604 if callable(matched_factory): 605 return matched_factory(name=factory_input_name) 606 return matched_factory 607