1"""An object-local variable management scheme.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21import collections 22 23import six 24 25from tensorflow.python.eager import context 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import ops 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import control_flow_ops 31from tensorflow.python.ops import gen_io_ops as io_ops 32from tensorflow.python.platform import tf_logging as logging 33from tensorflow.python.training.saving import saveable_object 34from tensorflow.python.util import tf_contextlib 35from tensorflow.python.util import tf_decorator 36 37# Key where the object graph proto is saved in a TensorBundle 38OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH" 39 40# A key indicating a variable's value in an object's checkpointed Tensors 41# (Trackable._gather_saveables_for_checkpoint). If this is the only key and 42# the object has no dependencies, then its value may be restored on object 43# creation (avoiding double assignment when executing eagerly). 
VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"

TrackableReference = collections.namedtuple(
    "TrackableReference",
    [
        # The local name for this dependency.
        "name",
        # The Trackable object being referenced.
        "ref"
    ])


class CheckpointInitialValue(ops.Tensor):
  """Tensor wrapper for managing update UIDs in `Variables`.

  When supplied as an initial value, objects of this type let a `Variable`
  (`Variable`, `ResourceVariable`, etc.) know the UID of the restore the initial
  value came from. This allows deferred restorations to be sequenced in the
  order the user specified them, and lets us fall back on assignment if an
  initial value is not set (e.g. due to a custom getter interfering).

  See comments in _add_variable_with_custom_getter for more information about
  how `CheckpointInitialValue` is used.
  """

  def __init__(self, checkpoint_position, shape=None):
    """Wrap the restored value tensor found at `checkpoint_position`.

    Args:
      checkpoint_position: The `CheckpointPosition` the value is read from.
      shape: If set, the static shape to impose on the restored tensor.
    """
    self.wrapped_value = checkpoint_position.value_tensors()[VARIABLE_VALUE_KEY]
    if shape:
      # We need to set the static shape information on the initializer if
      # possible so we don't get a variable with an unknown shape.
      self.wrapped_value.set_shape(shape)
    self._checkpoint_position = checkpoint_position

  def __getattr__(self, attr):
    # Delegate attribute lookups to the wrapped restore tensor so this object
    # can stand in for a regular Tensor wherever an initial value is expected.
    try:
      return getattr(self.wrapped_value, attr)
    except AttributeError:
      return self.__getattribute__(attr)

  @property
  def checkpoint_position(self):
    """The `CheckpointPosition` this initial value was read from."""
    return self._checkpoint_position


class NoRestoreSaveable(saveable_object.SaveableObject):
  """Embeds a tensor in a checkpoint with no restore ops."""

  def __init__(self, tensor, name, dtype=None, device=None):
    spec = saveable_object.SaveSpec(
        tensor, "", name, dtype=dtype, device=device)
    super(NoRestoreSaveable, self).__init__(tensor, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    # Nothing to restore; emit a no-op so callers always get an operation.
    return control_flow_ops.no_op()


@six.add_metaclass(abc.ABCMeta)
class PythonStateSaveable(saveable_object.SaveableObject):
  """An interface for saving/restoring volatile Python state."""

  @abc.abstractmethod
  def feed_dict_additions(self):
    """When running a graph, indicates fresh state to feed.

    Returns:
      A dictionary mapping `Tensor`s to current Python state.
    """
    pass

  @abc.abstractmethod
  def freeze(self):
    """Create a new `SaveableObject` which freezes current state as a constant.

    Used when executing eagerly to embed the current state as a constant, or
    when creating a static tf.compat.v1.train.Saver with the frozen current
    Python state.

    Returns:
      A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
      no Python state associated with it).
    """
    pass


class PythonStringStateSaveable(PythonStateSaveable):
  """Saves Python state in a checkpoint."""

  def __init__(self, name, state_callback, restore_callback=None):
    """Configure saving.

    Args:
      name: The checkpoint key to write to.
      state_callback: A function taking no arguments which returns a string.
        This function is run every time a checkpoint is written.
      restore_callback: A function taking a Python string, used to restore
        state. Optional; defaults to doing nothing, in which case it is ignored
        by status assertions such as assert_consumed().
    """
    self._has_trivial_state_callback = (restore_callback is None)

    def _state_callback_wrapper():
      # Run the user's callback outside any function-building graph so that
      # Python state is read eagerly/at the outermost graph level.
      with ops.init_scope():
        return state_callback()

    self._state_callback = _state_callback_wrapper
    self._restore_callback = restore_callback
    with ops.device("/cpu:0"):
      # Placeholder-like constant which is fed fresh state at save time.
      self._save_string = constant_op.constant("", dtype=dtypes.string)
    spec = saveable_object.SaveSpec(
        self._save_string, "", name, dtype=dtypes.string)
    super(PythonStringStateSaveable, self).__init__(self._save_string, [spec],
                                                    name)

  @property
  def optional_restore(self):
    """For values with no restore, relaxes assert_consumed()."""
    return self._has_trivial_state_callback

  def feed_dict_additions(self):
    """When running a graph, indicates fresh state to feed."""
    return {self._save_string: self._state_callback()}

  def freeze(self):
    """Create a frozen `SaveableObject` which saves the current state."""

    def _constant_state():
      return constant_op.constant(self._state_callback(), dtype=dtypes.string)

    return NoRestoreSaveable(
        tensor=_constant_state,
        dtype=dtypes.string,
        name=self.name,
        device="cpu:0")

  def python_restore(self, restored_strings):
    """Called to restore Python state."""
    if self._restore_callback:
      restored, = restored_strings
      self._restore_callback(restored)

  def restore(self, restored_tensors, restored_shapes):
    """Called to restore TensorFlow state (nothing to do)."""
    return control_flow_ops.no_op()


class CheckpointPosition(object):
  """Indicates a position within a `_CheckpointRestoreCoordinator`."""

  def __init__(self, checkpoint, proto_id):
    """Specify an object within a checkpoint.

    Args:
      checkpoint: A _CheckpointRestoreCoordinator object.
      proto_id: The index of this object in TrackableObjectGraph.nodes.
    """
    self._checkpoint = checkpoint
    self._proto_id = proto_id

  def restore(self, trackable):
    """Restore this value into `trackable`."""
    with ops.init_scope():
      if self.bind_object(trackable):
        # This object's correspondence with a checkpointed object is new, so
        # process deferred restorations for it and its dependencies.
        restore_ops = trackable._restore_from_checkpoint_position(self)  # pylint: disable=protected-access
        if restore_ops:
          self._checkpoint.new_restore_ops(restore_ops)

  def bind_object(self, trackable):
    """Set a checkpoint<->object correspondence and process slot variables.

    Args:
      trackable: The object to record a correspondence for.

    Returns:
      True if this is a new assignment, False if this object has already been
      mapped to a checkpointed `Object` proto.
    Raises:
      AssertionError: If another object is already bound to the `Object` proto.
    """
    checkpoint = self.checkpoint
    checkpoint.all_python_objects.add(trackable)
    current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
    checkpoint.matched_proto_ids.add(self._proto_id)
    if current_assignment is None:
      checkpoint.object_by_proto_id[self._proto_id] = trackable
      # Process slot-variable restorations which were waiting for this object
      # (as an optimizer) to be created.
      for deferred_slot_restoration in (
          checkpoint.deferred_slot_restorations.pop(self._proto_id, ())):
        trackable._create_or_restore_slot_variable(  # pylint: disable=protected-access
            slot_variable_position=CheckpointPosition(
                checkpoint=checkpoint,
                proto_id=deferred_slot_restoration.slot_variable_id),
            variable=deferred_slot_restoration.original_variable,
            slot_name=deferred_slot_restoration.slot_name)
      # Process slot restorations keyed on this object as the slotted-for
      # variable; they go to the optimizer if it exists, or are deferred.
      for slot_restoration in checkpoint.slot_restorations.pop(
          self._proto_id, ()):
        optimizer_object = checkpoint.object_by_proto_id.get(
            slot_restoration.optimizer_id, None)
        if optimizer_object is None:
          # The optimizer has not yet been created or tracked. Record in the
          # checkpoint that the slot variables need to be restored when it is.
          checkpoint.deferred_slot_restorations.setdefault(
              slot_restoration.optimizer_id, []).append(
                  _DeferredSlotVariableRestoration(
                      original_variable=trackable,
                      slot_variable_id=slot_restoration.slot_variable_id,
                      slot_name=slot_restoration.slot_name))
        else:
          optimizer_object._create_or_restore_slot_variable(  # pylint: disable=protected-access
              slot_variable_position=CheckpointPosition(
                  checkpoint=checkpoint,
                  proto_id=slot_restoration.slot_variable_id),
              variable=trackable,
              slot_name=slot_restoration.slot_name)
      return True  # New assignment
    else:
      # The object was already mapped for this checkpoint load, which means
      # we don't need to do anything besides check that the mapping is
      # consistent (if the dependency DAG is not a tree then there are
      # multiple paths to the same object).
      if current_assignment is not trackable:
        logging.warning((
            "Inconsistent references when loading the checkpoint into this "
            "object graph. Either the Trackable object references in the "
            "Python program have changed in an incompatible way, or the "
            "checkpoint was generated in an incompatible program.\n\nTwo "
            "checkpoint references resolved to different objects (%s and %s)."),
            current_assignment, trackable)
      return False  # Not a new assignment

  def is_simple_variable(self):
    """Determine whether this value is restorable with a Tensor initializer."""
    attributes = self.object_proto.attributes
    return (len(attributes) == 1 and
            attributes[0].name == VARIABLE_VALUE_KEY and
            not self.object_proto.children)

  def value_tensors(self):
    """Create value `Tensor`s for this object's attributes.

    Does not require that the Python object has been created. Used for
    restore-on-create when executing eagerly.

    Returns:
      A dictionary mapping from object attribute names to `Tensor`s.
    """
    value_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      with ops.init_scope():
        with ops.device("/cpu:0"):
          # Run the restore itself on the CPU.
          value, = io_ops.restore_v2(
              prefix=self._checkpoint.save_path_tensor,
              tensor_names=[checkpoint_key],
              shape_and_slices=[""],
              dtypes=[base_type],
              name="%s_checkpoint_read" % (serialized_tensor.name,))
        # Copy the value to the current device if necessary.
        value_tensors[serialized_tensor.name] = array_ops.identity(value)
    return value_tensors

  def gather_ops_or_named_saveables(self):
    """Looks up or creates SaveableObjects which don't have cached ops.

    Returns:
      A tuple of (existing restore ops, {checkpoint key: SaveableObject},
      [PythonStateSaveable]) for this object's attributes.
    """
    saveables = self.trackable._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
    # Name saveables based on the name this object had when it was checkpointed.
    named_saveables = {}
    python_saveables = []
    existing_restore_ops = []
    for serialized_tensor in self.object_proto.attributes:
      if context.executing_eagerly():
        existing_op = None
      else:
        existing_op = self._checkpoint.restore_ops_by_name.get(
            serialized_tensor.checkpoint_key, None)
      if existing_op is not None:
        existing_restore_ops.append(existing_op)
        continue

      # Only if we don't have cached ops for this SaveableObject, we'll see if
      # the SaveableObject itself has been cached. If not, we'll make it, and
      # either way we'll extract new ops from it (or if it has Python state to
      # restore, we'll run that).
      saveables_cache = self._checkpoint.graph_view.saveables_cache
      if saveables_cache is None:
        # No SaveableObject caching when executing eagerly.
        saveable = None
      else:
        # If we've already created and cached a SaveableObject for this
        # attribute, we can re-use it to avoid re-creating some ops when graph
        # building.
        saveable_list = saveables_cache.get(self.trackable,
                                            {}).get(serialized_tensor.name,
                                                    (None,))
        if len(saveable_list) == 1:
          # Almost every attribute will have exactly one SaveableObject.
          saveable, = saveable_list
        else:
          # Don't use cached SaveableObjects for partitioned variables, which is
          # the only case where we'd have a list of SaveableObjects. Op caching
          # will catch them.
          saveable = None
      if saveable is not None:
        # The name of this attribute has changed, so we need to re-generate
        # the SaveableObject.
        if serialized_tensor.checkpoint_key not in saveable.name:
          saveable = None
          del saveables_cache[self.trackable]
      if saveable is None:
        # If there was no cached SaveableObject, we should check if the Python
        # object has the attribute.
        saveable_factory = saveables.get(serialized_tensor.name, None)
        if saveable_factory is None:
          # Purposefully does not throw an exception if attributes have been
          # added or deleted. Stores unused attributes so an exception can be
          # raised if the user decides to check that everything in the
          # checkpoint was loaded.
          if not serialized_tensor.optional_restore:
            self._checkpoint.unused_attributes.setdefault(
                self._proto_id, []).append(serialized_tensor.name)
          continue
        if callable(saveable_factory):
          saveable = saveable_factory(name=serialized_tensor.checkpoint_key)
        else:
          saveable = saveable_factory
        if saveables_cache is not None:
          saveables_cache.setdefault(self.trackable,
                                     {})[serialized_tensor.name] = [saveable]
      if isinstance(saveable, PythonStateSaveable):
        python_saveables.append(saveable)
      else:
        named_saveables[serialized_tensor.checkpoint_key] = saveable
    return existing_restore_ops, named_saveables, python_saveables

  def restore_ops(self):
    """Create or fetch restore ops for this object's attributes.

    Requires that the `Trackable` Python object has been bound to an object
    ID in the checkpoint.

    Returns:
      A list of operations when graph building, or an empty list when executing
      eagerly.
    """
    (restore_ops, tensor_saveables,
     python_saveables) = self.gather_ops_or_named_saveables()
    restore_ops.extend(
        self._checkpoint.restore_saveables(tensor_saveables, python_saveables))
    return restore_ops

  @property
  def checkpoint(self):
    """The `_CheckpointRestoreCoordinator` this position belongs to."""
    return self._checkpoint

  @property
  def trackable(self):
    """The Python object currently bound to this position's proto id."""
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    """The `Object` proto for this position in the object graph."""
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def restore_uid(self):
    """UID of the restore this position is part of (for sequencing)."""
    return self._checkpoint.restore_uid

  def __repr__(self):
    return repr(self.object_proto)


_DeferredSlotVariableRestoration = collections.namedtuple(
    "_DeferredSlotVariableRestoration", [
        "original_variable",
        "slot_variable_id",
        "slot_name",
    ])

_SlotVariableRestoration = collections.namedtuple(
    "_SlotVariableRestoration",
    [
        # The checkpoint proto id of the optimizer object.
        "optimizer_id",
        # The checkpoint proto id of the slot variable.
        "slot_variable_id",
        "slot_name",
    ])


def no_automatic_dependency_tracking(method):
  """Disables automatic dependency tracking on attribute assignment.

  Use to decorate any method of a Trackable object. Attribute assignment in
  that method will not add dependencies (also respected in Model). Harmless if
  used in a class which does not do automatic dependency tracking (which means
  it's safe to use in base classes which may have subclasses which also inherit
  from Trackable).

  Args:
    method: The method to decorate.

  Returns:
    A decorated method which sets and un-sets automatic dependency tracking for
    the object the method is called on (not thread safe).
  """

  def _method_wrapper(self, *args, **kwargs):
    # Save and clear the tracking flag around the call so nested decorated
    # calls restore whatever state they found.
    previous_value = getattr(self, "_self_setattr_tracking", True)
    self._self_setattr_tracking = False  # pylint: disable=protected-access
    try:
      result = method(self, *args, **kwargs)
    finally:
      self._self_setattr_tracking = previous_value  # pylint: disable=protected-access
    return result

  return tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)


@tf_contextlib.contextmanager
def no_manual_dependency_tracking_scope(obj):
  """A context that disables manual dependency tracking for the given `obj`.

  Sometimes library methods might track objects on their own and we might want
  to disable that and do the tracking on our own. One can then use this context
  manager to disable the tracking the library method does and do your own
  tracking.

  For example:

  class TestLayer(tf.keras.Layer):
    def build():
      with no_manual_dependency_tracking_scope(self):
        var = self.add_variable("name1")  # Creates a var and doesn't track it
      self._track_trackable("name2", var)  # We track variable with name `name2`

  Args:
    obj: A trackable object.

  Yields:
    a scope in which the object doesn't track dependencies manually.
  """
  # pylint: disable=protected-access
  previous_value = getattr(obj, "_manual_tracking", True)
  obj._manual_tracking = False
  try:
    yield
  finally:
    obj._manual_tracking = previous_value


@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
  """A context that disables automatic dependency tracking when assigning attrs.

  Objects that inherit from AutoTrackable automatically create dependencies
  to trackable objects through attribute assignments, and wraps data structures
  (lists or dicts) with trackable classes. This scope may be used to temporarily
  disable this behavior. This works similar to the decorator
  `no_automatic_dependency_tracking`.

  Example usage:
  ```
  model = tf.keras.Model()
  model.arr1 = []  # Creates a ListWrapper object
  with no_automatic_dependency_tracking_scope(model):
    model.arr2 = []  # Creates a regular, untracked python list
  ```

  Args:
    obj: A trackable object.

  Yields:
    a scope in which the object doesn't track dependencies.
  """
  previous_value = getattr(obj, "_setattr_tracking", True)
  obj._setattr_tracking = False  # pylint: disable=protected-access
  try:
    yield
  finally:
    obj._setattr_tracking = previous_value  # pylint: disable=protected-access


class Trackable(object):
  """Base class for `Trackable` objects without automatic dependencies.

  This class has no __setattr__ override for performance reasons. Dependencies
  must be added explicitly. Unless attribute assignment is performance-critical,
  use `AutoTrackable` instead. Use `Trackable` for `isinstance`
  checks.
  """

  # For compatibility with wrapt.ObjectProxy, attributes are all prefixed with
  # _self_. We have some properties to forward semi-public attributes to their
  # _self_ equivalents.

  @property
  def _setattr_tracking(self):
    if not hasattr(self, "_self_setattr_tracking"):
      # Default to tracking attribute assignments.
      self._self_setattr_tracking = True
    return self._self_setattr_tracking

  @_setattr_tracking.setter
  def _setattr_tracking(self, value):
    self._self_setattr_tracking = value

  @property
  def _update_uid(self):
    return self._self_update_uid

  @_update_uid.setter
  def _update_uid(self, value):
    self._self_update_uid = value

  @property
  def _unconditional_checkpoint_dependencies(self):
    return self._self_unconditional_checkpoint_dependencies

  @property
  def _unconditional_dependency_names(self):
    return self._self_unconditional_dependency_names

  @property
  def _name_based_restores(self):
    return self._self_name_based_restores

  # Trackable does not do automatic dependency tracking, but uses the
  # no_automatic_dependency_tracking decorator so it can avoid adding
  # dependencies if a subclass is Trackable / inherits from Model (both of
  # which have __setattr__ overrides).
  @no_automatic_dependency_tracking
  def _maybe_initialize_trackable(self):
    """Initialize dependency management.

    Not __init__, since most objects will forget to call it.
    """
    if hasattr(self, "_self_unconditional_checkpoint_dependencies"):
      # __init__ already called. This check means that we don't need
      # Trackable.__init__() in the constructor of every TensorFlow object.
      return
    # A list of TrackableReference objects. Some classes implementing
    # `Trackable`, notably `Optimizer`s, may override the
    # _checkpoint_dependencies property with conditional dependencies
    # (e.g. based on the current graph when saving).
    self._self_unconditional_checkpoint_dependencies = []
    # Maps names -> Trackable objects
    self._self_unconditional_dependency_names = {}
    # Restorations for other Trackable objects on which this object may
    # eventually depend. Maps local name -> CheckpointPosition list. Optimizers
    # tack on conditional dependencies, and so need separate management of
    # deferred dependencies too.
    self._self_unconditional_deferred_dependencies = {}
    # The UID of the highest assignment to this object. Used to ensure that the
    # last requested assignment determines the final value of an object.
    if hasattr(self, "_self_update_uid"):
      raise AssertionError(
          "Internal error: the object had an update UID set before its "
          "initialization code was run.")
    self._self_update_uid = -1
    # When executing eagerly, holds a collection of _NameBasedRestoreCoordinator
    # instances, which should be checked when creating variables or other
    # saveables. These are passed on recursively to all dependencies, since
    # unlike object-based checkpoint restores we don't know which subgraph is
    # being restored in advance. This mechanism is only necessary for
    # restore-on-create when executing eagerly, and so is unused when graph
    # building.
    self._self_name_based_restores = set()

  @property
  def _object_identifier(self):
    """String used to identify this object in a SavedModel.

    Generally, the object identifier is constant across objects of the same
    class, while the metadata field is used for instance-specific data.

    Returns:
      String object identifier.
    """
    return "_generic_user_object"

  @property
  def _tracking_metadata(self):
    """String containing object metadata, which is saved to the SavedModel."""
    return ""

  def _no_dependency(self, value):
    """If automatic dependency tracking is enabled, ignores `value`."""
    return value

  def _name_based_attribute_restore(self, checkpoint):
    """Restore the object's attributes from a name-based checkpoint."""
    self._self_name_based_restores.add(checkpoint)
    if self._self_update_uid < checkpoint.restore_uid:
      checkpoint.eager_restore(self)
      self._self_update_uid = checkpoint.restore_uid

  @property
  def _checkpoint_dependencies(self):
    """All dependencies of this object.

    May be overridden to include conditional dependencies.

    Returns:
      A list of `TrackableReference` objects indicating named
      `Trackable` dependencies which should be saved along with this
      object.
    """
    return self._self_unconditional_checkpoint_dependencies

  @property
  def _deferred_dependencies(self):
    """A dictionary with deferred dependencies.

    Stores restorations for other Trackable objects on which this object
    may eventually depend. May be overridden by sub-classes (e.g. Optimizers use
    conditional dependencies based on the current graph, and so need separate
    management of deferred dependencies too).

    Returns:
      A dictionary mapping from local name to a list of CheckpointPosition
      objects.
    """
    return self._self_unconditional_deferred_dependencies

  def _lookup_dependency(self, name):
    """Look up a dependency by name.

    May be overridden to include conditional dependencies.

    Args:
      name: The local name of the dependency.

    Returns:
      A `Trackable` object, or `None` if no dependency by this name was
      found.
    """
    return self._self_unconditional_dependency_names.get(name, None)

  def _add_variable_with_custom_getter(self,
                                       name,
                                       shape=None,
                                       dtype=dtypes.float32,
                                       initializer=None,
                                       getter=None,
                                       overwrite=False,
                                       **kwargs_for_getter):
    """Restore-on-create for a variable to be saved with this `Trackable`.

    If the user has requested that this object or another `Trackable` which
    depends on this object be restored from a checkpoint (deferred loading
    before variable object creation), `initializer` may be ignored and the value
    from the checkpoint used instead.

    Args:
      name: A name for the variable. Must be unique within this object.
      shape: The shape of the variable.
      dtype: The data type of the variable.
      initializer: The initializer to use. Ignored if there is a deferred
        restoration left over from a call to
        `_restore_from_checkpoint_position`.
      getter: The getter to wrap which actually fetches the variable.
      overwrite: If True, disables unique name and type checks.
      **kwargs_for_getter: Passed to the getter.

    Returns:
      The new variable object.

    Raises:
      ValueError: If the variable name is not unique.
    """
    self._maybe_initialize_trackable()
    with ops.init_scope():
      if context.executing_eagerly():
        # If this is a variable with a single Tensor stored in the checkpoint,
        # we can set that value as an initializer rather than initializing and
        # then assigning (when executing eagerly). This call returns None if
        # there is nothing to restore.
        checkpoint_initializer = self._preload_simple_restoration(
            name=name, shape=shape)
      else:
        checkpoint_initializer = None
      if (checkpoint_initializer is not None and
          not (isinstance(initializer, CheckpointInitialValue) and
               (initializer.restore_uid > checkpoint_initializer.restore_uid))):
        # If multiple Trackable objects are "creating" the same variable
        # via the magic of custom getters, the one with the highest restore UID
        # (the one called last) has to make the final initializer. If another
        # custom getter interrupts this process by overwriting the initializer,
        # then we'll catch that when we call _track_trackable. So this is
        # "best effort" to set the initializer with the highest restore UID.
        initializer = checkpoint_initializer
        shape = None
    new_variable = getter(
        name=name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        **kwargs_for_getter)

    # If we set an initializer and the variable processed it, tracking will not
    # assign again. It will add this variable to our dependencies, and if there
    # is a non-trivial restoration queued, it will handle that. This also
    # handles slot variables.
    if not overwrite or isinstance(new_variable, Trackable):
      return self._track_trackable(new_variable, name=name, overwrite=overwrite)
    else:
      # TODO(allenl): Some variable types are not yet supported. Remove this
      # fallback once all get_variable() return types are Trackable.
      return new_variable

  def _preload_simple_restoration(self, name, shape):
    """Return a dependency's value for restore-on-create.

    Note the restoration is not deleted; if for some reason preload is called
    and then not assigned to the variable (for example because a custom getter
    overrides the initializer), the assignment will still happen once the
    variable is tracked (determined based on checkpoint.restore_uid).

    Args:
      name: The object-local name of the dependency holding the variable's
        value.
      shape: The shape of the variable being loaded into.

    Returns:
      A callable for use as a variable's initializer/initial_value, or None if
      one should not be set (either because there was no variable with this name
      in the checkpoint or because it needs more complex deserialization). Any
      non-trivial deserialization will happen when the variable object is
      tracked.
    """
    deferred_dependencies_list = self._deferred_dependencies.get(name, ())
    if not deferred_dependencies_list:
      # Nothing to do; we don't have a restore for this dependency queued up.
      return
    for checkpoint_position in deferred_dependencies_list:
      if not checkpoint_position.is_simple_variable():
        # If _any_ pending restoration is too complicated to fit in an
        # initializer (because it has dependencies, or because there are
        # multiple Tensors to restore), bail and let the general tracking code
        # handle it.
        return None
    checkpoint_position = max(
        deferred_dependencies_list,
        key=lambda restore: restore.checkpoint.restore_uid)
    return CheckpointInitialValue(
        checkpoint_position=checkpoint_position, shape=shape)

  def _track_trackable(self, trackable, name, overwrite=False):
    """Declare a dependency on another `Trackable` object.

    Indicates that checkpoints for this object should include variables from
    `trackable`.

    Variables in a checkpoint are mapped to `Trackable`s based on the names
    provided when the checkpoint was written. To avoid breaking existing
    checkpoints when modifying a class, neither variable names nor dependency
    names (the names passed to `_track_trackable`) may change.

    Args:
      trackable: A `Trackable` which this object depends on.
      name: A local name for `trackable`, used for loading checkpoints into the
        correct objects.
      overwrite: Boolean, whether silently replacing dependencies is OK. Used
        for __setattr__, where throwing an error on attribute reassignment would
        be inappropriate.

    Returns:
      `trackable`, for convenience when declaring a dependency and
      assigning to a member variable in one statement.

    Raises:
      TypeError: If `trackable` does not inherit from `Trackable`.
      ValueError: If another object is already tracked by this name.
    """
    self._maybe_initialize_trackable()
    if not isinstance(trackable, Trackable):
      raise TypeError(("Trackable._track_trackable() passed type %s, not a "
                       "Trackable.") % (type(trackable),))
    if not getattr(self, "_manual_tracking", True):
      # Manual tracking disabled (see no_manual_dependency_tracking_scope);
      # return without recording a dependency.
      return trackable
    new_reference = TrackableReference(name=name, ref=trackable)
    current_object = self._lookup_dependency(name)
    if (current_object is not None and current_object is not trackable):
      if not overwrite:
        raise ValueError(
            ("Called Trackable._track_trackable() with name='%s', "
             "but a Trackable with this name is already declared as a "
             "dependency. Names must be unique (or overwrite=True).") % (name,))
      # This is a weird thing to do, but we're not going to stop people from
      # using __setattr__.
      for index, (old_name, _) in enumerate(
          self._self_unconditional_checkpoint_dependencies):
        if name == old_name:
          self._self_unconditional_checkpoint_dependencies[
              index] = new_reference
    elif current_object is None:
      self._self_unconditional_checkpoint_dependencies.append(new_reference)
      self._handle_deferred_dependencies(name=name, trackable=trackable)
    self._self_unconditional_dependency_names[name] = trackable
    return trackable

  def _handle_deferred_dependencies(self, name, trackable):
    """Pop and load any deferred checkpoint restores into `trackable`.

    This method does not add a new dependency on `trackable`, but it does
    check if any outstanding/deferred dependencies have been queued waiting for
    this dependency to be added (matched based on `name`). If so,
    `trackable` and its dependencies are restored. The restorations are
    considered fulfilled and so are deleted.

    `_track_trackable` is more appropriate for adding a
    normal/unconditional dependency, and includes handling for deferred
    restorations. This method allows objects such as `Optimizer` to use the same
    restoration logic while managing conditional dependencies themselves, by
    overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the
    object's dependencies based on the context it is saved/restored in (a single
    optimizer instance can have state associated with multiple graphs).

    Args:
      name: The name of the dependency within this object (`self`), used to
        match `trackable` with values saved in a checkpoint.
      trackable: The Trackable object to restore (inheriting from `Trackable`).
    """
    self._maybe_initialize_trackable()
    trackable._maybe_initialize_trackable()  # pylint: disable=protected-access
    deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
    # Restore in descending restore_uid order so the most recent restore wins.
    for checkpoint_position in sorted(
        deferred_dependencies_list,
        key=lambda restore: restore.checkpoint.restore_uid,
        reverse=True):
      checkpoint_position.restore(trackable)

    # Pass on any name-based restores queued in this object.
    for name_based_restore in sorted(
        self._self_name_based_restores,
        key=lambda checkpoint: checkpoint.restore_uid,
        reverse=True):
      trackable._name_based_attribute_restore(name_based_restore)  # pylint: disable=protected-access

  def _restore_from_checkpoint_position(self, checkpoint_position):
    """Restore this object and its dependencies (may be deferred)."""
    # Attempt a breadth-first traversal, since presumably the user has more
    # control over shorter paths. If we don't have all of the dependencies at
    # this point, the end result is not breadth-first (since other deferred
    # traversals will happen later).
    visit_queue = collections.deque([checkpoint_position])
    restore_ops = []
    tensor_saveables = {}
    python_saveables = []
    while visit_queue:
      current_position = visit_queue.popleft()
      new_restore_ops, new_tensor_saveables, new_python_saveables = (
          current_position.trackable  # pylint: disable=protected-access
          ._single_restoration_from_checkpoint_position(
              checkpoint_position=current_position,
              visit_queue=visit_queue))
      restore_ops.extend(new_restore_ops)
      tensor_saveables.update(new_tensor_saveables)
      python_saveables.extend(new_python_saveables)
    restore_ops.extend(
        current_position.checkpoint.restore_saveables(
            tensor_saveables, python_saveables))
    return restore_ops

  def _single_restoration_from_checkpoint_position(self, checkpoint_position,
                                                   visit_queue):
    """Restore this object, and either queue its dependencies or defer them."""
    self._maybe_initialize_trackable()
    checkpoint = checkpoint_position.checkpoint
    # If the UID of this restore is lower than our current update UID, we don't
    # need to actually restore the object. However, we should pass the
    # restoration on to our dependencies.
918 if checkpoint.restore_uid > self._self_update_uid: 919 restore_ops, tensor_saveables, python_saveables = ( 920 checkpoint_position.gather_ops_or_named_saveables()) 921 self._self_update_uid = checkpoint.restore_uid 922 else: 923 restore_ops = () 924 tensor_saveables = {} 925 python_saveables = () 926 for child in checkpoint_position.object_proto.children: 927 child_position = CheckpointPosition( 928 checkpoint=checkpoint, proto_id=child.node_id) 929 local_object = self._lookup_dependency(child.local_name) 930 if local_object is None: 931 # We don't yet have a dependency registered with this name. Save it 932 # in case we do. 933 self._deferred_dependencies.setdefault(child.local_name, 934 []).append(child_position) 935 else: 936 if child_position.bind_object(trackable=local_object): 937 # This object's correspondence is new, so dependencies need to be 938 # visited. Delay doing it so that we get a breadth-first dependency 939 # resolution order (shallowest paths first). The caller is responsible 940 # for emptying visit_queue. 941 visit_queue.append(child_position) 942 return restore_ops, tensor_saveables, python_saveables 943 944 def _gather_saveables_for_checkpoint(self): 945 """Returns a dictionary of values to checkpoint with this object. 946 947 Keys in the returned dictionary are local to this object and in a separate 948 namespace from dependencies. Values may either be `SaveableObject` factories 949 or variables easily converted to `SaveableObject`s (as in 950 `tf.compat.v1.train.Saver`'s 951 `var_list` constructor argument). 952 953 `SaveableObjects` have a name set, which Trackable needs to generate 954 itself. So rather than returning `SaveableObjects` directly, this method 955 should return a dictionary of callables which take `name` arguments and 956 return `SaveableObjects` with that name. 
957 958 If this object may also be passed to the global-name-based 959 `tf.compat.v1.train.Saver`, 960 the returned callables should have a default value for their name argument 961 (i.e. be callable with no arguments). 962 963 Returned values must be saved only by this object; if any value may be 964 shared, it should instead be a dependency. For example, variable objects 965 save their own values with the key `VARIABLE_VALUE_KEY`, but objects which 966 reference variables simply add a dependency. 967 968 Returns: 969 The dictionary mapping attribute names to `SaveableObject` factories 970 described above. For example: 971 {VARIABLE_VALUE_KEY: 972 lambda name="global_name_for_this_object": 973 SaveableObject(name=name, ...)} 974 """ 975 return {} 976 977 def _list_extra_dependencies_for_serialization(self, serialization_cache): 978 """Lists extra dependencies to serialize. 979 980 Internal sub-classes can override this method to return extra dependencies 981 that should be saved with the object during SavedModel serialization. For 982 example, this is used to save `trainable_variables` in Keras models. The 983 python property `trainable_variables` contains logic to iterate through the 984 weights from the model and its sublayers. The serialized Keras model saves 985 `trainable_weights` as a trackable list of variables. 986 987 PLEASE NOTE when overriding this method: 988 1. This function may only generate new trackable objects the first time it 989 is called. 990 2. The returned dictionary must not have naming conflicts with 991 dependencies tracked by the root. In other words, if the root is 992 tracking `object_1` with name 'x', and this functions returns 993 `{'x': object_2}`, an error is raised when saving. 994 995 Args: 996 serialization_cache: A dictionary shared between all objects in the same 997 object graph. This object is passed to both 998 `_list_extra_dependencies_for_serialization` and 999 `_list_functions_for_serialization`. 
1000 1001 Returns: 1002 A dictionary mapping attribute names to trackable objects. 1003 """ 1004 del serialization_cache 1005 return dict() 1006 1007 def _list_functions_for_serialization(self, serialization_cache): 1008 """Lists the functions of this trackable to serialize. 1009 1010 Internal sub-classes can override this with specific logic. E.g. 1011 `AutoTrackable` provides an implementation that returns the `attr` 1012 that return functions. 1013 1014 Args: 1015 serialization_cache: Dictionary passed to all objects in the same object 1016 graph during serialization. 1017 1018 Returns: 1019 A dictionary mapping attribute names to `Function` or 1020 `ConcreteFunction`. 1021 """ 1022 del serialization_cache 1023 return dict() 1024