1"""An object-local variable management scheme.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21import collections 22 23import six 24 25from tensorflow.python.eager import context 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import ops 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import control_flow_ops 31from tensorflow.python.ops import gen_io_ops as io_ops 32from tensorflow.python.platform import tf_logging as logging 33from tensorflow.python.training.saving import saveable_object 34from tensorflow.python.util import tf_contextlib 35from tensorflow.python.util import tf_decorator 36from tensorflow.python.util.tf_export import tf_export 37 38# Key where the object graph proto is saved in a TensorBundle 39OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH" 40 41# A key indicating a variable's value in an object's checkpointed Tensors 42# (Trackable._gather_saveables_for_checkpoint). If this is the only key and 43# the object has no dependencies, then its value may be restored on object 44# creation (avoiding double assignment when executing eagerly). 
VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
# Attribute name presumably used for an object's JSON-serialized config; it is
# not referenced in this part of the file (TODO confirm usage).
OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"


@tf_export("__internal__.tracking.TrackableReference", v1=[])
class TrackableReference(
    collections.namedtuple("TrackableReference", ["name", "ref"])):
  """A named reference to a trackable object for use with the `Trackable` class.

  These references mark named `Trackable` dependencies of a `Trackable` object
  and should be created when overriding `Trackable._checkpoint_dependencies`.
  Being a namedtuple, a reference also unpacks as `(name, ref)`.

  Attributes:
    name: The local name for this dependency.
    ref: The `Trackable` object being referenced.
  """


# TODO(bfontain): Update once sharded initialization interface is finalized.
# Describes one shard of a variable for restore-on-create. `shape` is the
# shard's length along each dimension and `offset` its start index along each
# dimension; CheckpointInitialValue zips them into "offset,length" slice specs.
ShardInfo = collections.namedtuple(
    "CheckpointInitialValueShardInfo", ["shape", "offset"])


@tf_export("__internal__.tracking.CheckpointInitialValueCallable", v1=[])
class CheckpointInitialValueCallable(object):
  """A callable object that returns a CheckpointInitialValue.

  Instances are handed to variable creation as the initializer for
  restore-on-create; calling one reads the checkpointed value described by
  `checkpoint_position`. See CheckpointInitialValue for more information.
  """

  def __init__(self, checkpoint_position):
    self._checkpoint_position = checkpoint_position

  @property
  def checkpoint_position(self):
    return self._checkpoint_position

  def __call__(self, shape=None, dtype=None, shard_info=None):
    # Note that the signature here is for compatibility with normal callable
    # initializers which take shape and dtype. Although dtype isn't used, it
    # will get passed in by a functools.partial wrapper in places like
    # base_layer_utils.py's make_variable.
    # `shape` is only consulted when `shard_info` is given, and is then
    # required (see CheckpointInitialValue.__init__).
    return CheckpointInitialValue(
        self._checkpoint_position, shape, shard_info=shard_info)

  @property
  def restore_uid(self):
    # UID of the restore this value comes from. Compared in
    # Trackable._add_variable_with_custom_getter so the most recent restore
    # wins when several compete to initialize the same variable.
    return self._checkpoint_position.restore_uid


@tf_export("__internal__.tracking.CheckpointInitialValue", v1=[])
class CheckpointInitialValue(ops.Tensor):
  """Tensor wrapper for managing update UIDs in `Variables`.

  When supplied as an initial value, objects of this type let a `Variable`
  (`Variable`, `ResourceVariable`, etc.) know the UID of the restore the initial
  value came from. This allows deferred restorations to be sequenced in the
  order the user specified them, and lets us fall back on assignment if an
  initial value is not set (e.g. due to a custom getter interfering).

  See comments in _add_variable_with_custom_getter for more information about
  how `CheckpointInitialValue` is used.
  """

  def __init__(self, checkpoint_position, shape=None, shard_info=None):
    """Reads the checkpointed VARIABLE_VALUE for `checkpoint_position`.

    Note that `ops.Tensor.__init__` is not called; attribute lookups that fail
    on this object are forwarded to `wrapped_value` by `__getattr__`.

    Args:
      checkpoint_position: The `CheckpointPosition` whose VARIABLE_VALUE
        attribute is read.
      shape: The full (unsharded) shape of the variable. Only used, and then
        required, when `shard_info` is given.
      shard_info: Optional `ShardInfo`. When set, only that slice of the
        checkpointed tensor is restored.
    """
    if shard_info:
      # Build a RestoreV2 shape-and-slice string: the full shape's dimensions
      # separated by spaces, a trailing space, then one "offset,length" pair
      # per dimension joined with ":".
      full_shape_str = " ".join("%d" % d for d in shape) + " "
      slice_spec = ":".join(
          "%d,%d" % (o, s) for o, s in zip(shard_info.offset, shard_info.shape))
      shape_and_slice = full_shape_str + slice_spec
    else:
      # An empty spec restores the full tensor (see value_tensors).
      shape_and_slice = ""
    self.wrapped_value = checkpoint_position.value_tensors(
        {VARIABLE_VALUE_KEY: shape_and_slice})[VARIABLE_VALUE_KEY]
    self._checkpoint_position = checkpoint_position

  def __getattr__(self, attr):
    # Only reached when normal lookup fails. Delegate to the restored tensor;
    # if it lacks the attribute too, fall back to default lookup, which
    # normally re-raises AttributeError.
    # NOTE(review): if `wrapped_value` has not been assigned yet, the
    # `self.wrapped_value` access below re-enters __getattr__ and recurses
    # until RecursionError.
    try:
      return getattr(self.wrapped_value, attr)
    except AttributeError:
      return self.__getattribute__(attr)

  @property
  def checkpoint_position(self):
    return self._checkpoint_position


class NoRestoreSaveable(saveable_object.SaveableObject):
  """Embeds a tensor in a checkpoint with no restore ops."""

  def __init__(self, tensor, name, dtype=None, device=None):
    # A single SaveSpec writes `tensor` under the checkpoint key `name`.
    spec = saveable_object.SaveSpec(
        tensor, "", name, dtype=dtype, device=device)
    super(NoRestoreSaveable, self).__init__(tensor, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    # The value is only ever written; restoring does nothing.
    return control_flow_ops.no_op()


@six.add_metaclass(abc.ABCMeta)
class PythonStateSaveable(saveable_object.SaveableObject):
  """An interface for saving/restoring volatile Python state."""

  @abc.abstractmethod
  def feed_dict_additions(self):
    """When running a graph, indicates fresh state to feed.

    Returns:
      A dictionary mapping `Tensor`s to current Python state.
    """
    pass

  @abc.abstractmethod
  def freeze(self):
    """Create a new `SaveableObject` which freezes current state as a constant.

    Used when executing eagerly to embed the current state as a constant, or
    when creating a static tf.compat.v1.train.Saver with the frozen current
    Python state.

    Returns:
      A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
      no Python state associated with it).
    """
    pass


class PythonStringStateSaveable(PythonStateSaveable):
  """Saves Python state in a checkpoint."""

  def __init__(self, name, state_callback, restore_callback=None):
    """Configure saving.

    Args:
      name: The checkpoint key to write to.
      state_callback: A function taking no arguments which returns a string.
        This function is run every time a checkpoint is written.
      restore_callback: A function taking a Python string, used to restore
        state. Optional; defaults to doing nothing, in which case it is ignored
        by status assertions such as assert_consumed().
    """
    self._has_trivial_state_callback = (restore_callback is None)

    # Run the user callback under init_scope so it is evaluated outside any
    # function/graph currently being built.
    def _state_callback_wrapper():
      with ops.init_scope():
        return state_callback()

    self._state_callback = _state_callback_wrapper
    self._restore_callback = restore_callback
    # Placeholder string tensor; its value is supplied via
    # feed_dict_additions() when running a graph.
    with ops.device("/cpu:0"):
      self._save_string = constant_op.constant("", dtype=dtypes.string)
    spec = saveable_object.SaveSpec(
        self._save_string, "", name, dtype=dtypes.string)
    super(PythonStringStateSaveable, self).__init__(self._save_string, [spec],
                                                    name)

  @property
  def optional_restore(self):
    """For values with no restore, relaxes assert_consumed()."""
    return self._has_trivial_state_callback

  def feed_dict_additions(self):
    """When running a graph, indicates fresh state to feed."""
    return {self._save_string: self._state_callback()}

  def freeze(self):
    """Create a frozen `SaveableObject` which saves the current state."""

    # Passed as a callable rather than a tensor; the state callback runs when
    # this function is invoked.
    def _constant_state():
      return constant_op.constant(self._state_callback(), dtype=dtypes.string)

    return NoRestoreSaveable(
        tensor=_constant_state,
        dtype=dtypes.string,
        name=self.name,
        device="cpu:0")

  def python_restore(self, restored_strings):
    """Called to restore Python state.

    Args:
      restored_strings: A one-element sequence holding the restored string.
        Ignored when no `restore_callback` was configured.
    """
    if self._restore_callback:
      restored, = restored_strings
      self._restore_callback(restored)

  def restore(self, restored_tensors, restored_shapes):
    """Called to restore TensorFlow state (nothing to do)."""
    return control_flow_ops.no_op()


class CheckpointPosition(object):
  """Indicates a position within a `_CheckpointRestoreCoordinator`."""

  # Lightweight handle: just the coordinator and an index into its
  # object-graph proto; everything else is looked up through them.
  __slots__ = ["_checkpoint", "_proto_id"]

  def __init__(self, checkpoint, proto_id):
    """Specify an object within a checkpoint.

    Args:
      checkpoint: A _CheckpointRestoreCoordinator object.
      proto_id: The index of this object in TrackableObjectGraph.nodes.
    """
    self._checkpoint = checkpoint
    self._proto_id = proto_id

  def restore(self, trackable):
    """Restore this value into `trackable`."""
    with ops.init_scope():
      if self.bind_object(trackable):
        # This object's correspondence with a checkpointed object is new, so
        # process deferred restorations for it and its dependencies.
        restore_ops = trackable._restore_from_checkpoint_position(self)  # pylint: disable=protected-access
        if restore_ops:
          self._checkpoint.new_restore_ops(restore_ops)

  def bind_object(self, trackable):
    """Set a checkpoint<->object correspondence and process slot variables.

    On a new binding, this also creates/restores any slot variables that were
    waiting on this object, either as the optimizer (deferred slot
    restorations) or as the original variable (slot restorations).

    Args:
      trackable: The object to record a correspondence for.

    Returns:
      True if this is a new assignment, False if this object's proto has
      already been mapped to a Python object during this checkpoint load.
      If the existing mapping points at a different object, a warning is
      logged (no exception is raised) and False is returned.
    """
    checkpoint = self.checkpoint
    checkpoint.all_python_objects.add(trackable)
    current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
    checkpoint.matched_proto_ids.add(self._proto_id)
    if current_assignment is None:
      checkpoint.object_by_proto_id[self._proto_id] = trackable
      # `trackable` is an optimizer whose slot restorations were queued before
      # it existed; create/restore those slots now.
      for deferred_slot_restoration in (
          checkpoint.deferred_slot_restorations.pop(self._proto_id, ())):
        trackable._create_or_restore_slot_variable(  # pylint: disable=protected-access
            slot_variable_position=CheckpointPosition(
                checkpoint=checkpoint,
                proto_id=deferred_slot_restoration.slot_variable_id),
            variable=deferred_slot_restoration.original_variable,
            slot_name=deferred_slot_restoration.slot_name)
      # `trackable` is a variable that has slot variables in the checkpoint.
      for slot_restoration in checkpoint.slot_restorations.pop(
          self._proto_id, ()):
        optimizer_object = checkpoint.object_by_proto_id.get(
            slot_restoration.optimizer_id, None)
        if optimizer_object is None:
          # The optimizer has not yet been created or tracked. Record in the
          # checkpoint that the slot variables need to be restored when it is.
          checkpoint.deferred_slot_restorations.setdefault(
              slot_restoration.optimizer_id, []).append(
                  _DeferredSlotVariableRestoration(
                      original_variable=trackable,
                      slot_variable_id=slot_restoration.slot_variable_id,
                      slot_name=slot_restoration.slot_name))

        # `optimizer_object` can be a `Checkpoint` when user only needs the
        # attributes the optimizer holds, such as `iterations`. In those cases,
        # it would not have the optimizer's `_create_or_restore_slot_variable`
        # method.
        elif hasattr(optimizer_object, "_create_or_restore_slot_variable"):
          optimizer_object._create_or_restore_slot_variable(  # pylint: disable=protected-access
              slot_variable_position=CheckpointPosition(
                  checkpoint=checkpoint,
                  proto_id=slot_restoration.slot_variable_id),
              variable=trackable,
              slot_name=slot_restoration.slot_name)
      return True  # New assignment
    else:
      # The object was already mapped for this checkpoint load, which means
      # we don't need to do anything besides check that the mapping is
      # consistent (if the dependency DAG is not a tree then there are
      # multiple paths to the same object).
      if current_assignment is not trackable:
        logging.warning((
            "Inconsistent references when loading the checkpoint into this "
            "object graph. Either the Trackable object references in the "
            "Python program have changed in an incompatible way, or the "
            "checkpoint was generated in an incompatible program.\n\nTwo "
            "checkpoint references resolved to different objects (%s and %s)."),
                        current_assignment, trackable)
      return False  # Not a new assignment

  def is_simple_variable(self):
    """Determine whether this value is restorable with a Tensor initializer.

    True only when the checkpointed object has exactly one attribute, named
    VARIABLE_VALUE_KEY, and no children.
    """
    attributes = self.object_proto.attributes
    return (len(attributes) == 1 and
            attributes[0].name == VARIABLE_VALUE_KEY and
            not self.object_proto.children)

  def value_tensors(self, shape_and_slices=None):
    """Create value `Tensor`s for this object's attributes.

    Does not require that the Python object has been created. Used for
    restore-on-create when executing eagerly.

    Args:
      shape_and_slices: A dict mapping from object attribute names to a shape
        and slice string that will be passed to a RestoreV2 op. If the dict is
        None or if an object attribute is not in the dict, the full tensor will
        be restored.

    Returns:
      A dictionary mapping from object attribute names to `Tensor`s.
    """
    value_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      # Loop-invariant: the configured I/O device, defaulting to CPU.
      io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
      with ops.init_scope():
        with ops.device(io_device):
          # Run the restore itself on the I/O device (CPU unless
          # `experimental_io_device` is set).
          if (shape_and_slices is not None and
              serialized_tensor.name in shape_and_slices):
            shape_and_slice = shape_and_slices[serialized_tensor.name]
          else:
            shape_and_slice = ""
          value, = io_ops.restore_v2(
              prefix=self._checkpoint.save_path_tensor,
              tensor_names=[checkpoint_key],
              shape_and_slices=[shape_and_slice],
              dtypes=[base_type],
              name="%s_checkpoint_read" % (serialized_tensor.name,))
        # Copy the value to the current device if necessary.
        value_tensors[serialized_tensor.name] = array_ops.identity(value)
    return value_tensors

  def gather_ops_or_named_saveables(self):
    """Looks up or creates SaveableObjects which don't have cached ops.

    Attributes that the Python object no longer provides are recorded in the
    checkpoint's `unused_attributes` (unless marked `optional_restore`) and
    skipped.

    Returns:
      A tuple `(existing_restore_ops, named_saveables, python_saveables)`:
      cached restore ops (graph mode only), a dict mapping checkpoint keys to
      non-Python `SaveableObject`s, and a list of `PythonStateSaveable`s.
    """
    saveables = self.trackable._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
    # Name saveables based on the name this object had when it was checkpointed.
    named_saveables = {}
    python_saveables = []
    existing_restore_ops = []
    for serialized_tensor in self.object_proto.attributes:
      if context.executing_eagerly():
        existing_op = None
      else:
        existing_op = self._checkpoint.restore_ops_by_name.get(
            serialized_tensor.checkpoint_key, None)
      if existing_op is not None:
        existing_restore_ops.append(existing_op)
        continue

      # Only if we don't have cached ops for this SaveableObject, we'll see if
      # the SaveableObject itself has been cached. If not, we'll make it, and
      # either way we'll extract new ops from it (or if it has Python state to
      # restore, we'll run that).
      saveables_cache = self._checkpoint.graph_view.saveables_cache
      if saveables_cache is None:
        # No SaveableObject caching when executing eagerly.
        saveable = None
      else:
        # If we've already created and cached a SaveableObject for this
        # attribute, we can re-use it to avoid re-creating some ops when graph
        # building.
        saveable_list = saveables_cache.get(self.trackable,
                                            {}).get(serialized_tensor.name,
                                                    (None,))
        if len(saveable_list) == 1:
          # Almost every attribute will have exactly one SaveableObject.
          saveable, = saveable_list
        else:
          # Don't use cached SaveableObjects for partitioned variables, which is
          # the only case where we'd have a list of SaveableObjects. Op caching
          # will catch them.
          saveable = None
        if saveable is not None:
          # The name of this attribute has changed, so we need to re-generate
          # the SaveableObject. Note this drops the cache entries for every
          # attribute of this trackable, not just this one.
          if serialized_tensor.checkpoint_key not in saveable.name:
            saveable = None
            del saveables_cache[self.trackable]
      if saveable is None:
        # If there was no cached SaveableObject, we should check if the Python
        # object has the attribute.
        saveable_factory = saveables.get(serialized_tensor.name, None)
        if saveable_factory is None:
          # Purposefully does not throw an exception if attributes have been
          # added or deleted. Stores unused attributes so an exception can be
          # raised if the user decides to check that everything in the
          # checkpoint was loaded.
          if not serialized_tensor.optional_restore:
            self._checkpoint.unused_attributes.setdefault(
                self._proto_id, []).append(serialized_tensor.name)
          continue
        # Factories may be callables taking the checkpoint key as `name`, or
        # ready-made SaveableObjects.
        if callable(saveable_factory):
          saveable = saveable_factory(name=serialized_tensor.checkpoint_key)
        else:
          saveable = saveable_factory
        if saveables_cache is not None:
          saveables_cache.setdefault(self.trackable,
                                     {})[serialized_tensor.name] = [saveable]
      if isinstance(saveable, PythonStateSaveable):
        python_saveables.append(saveable)
      else:
        named_saveables[serialized_tensor.checkpoint_key] = saveable
    return existing_restore_ops, named_saveables, python_saveables

  def restore_ops(self):
    """Create or fetch restore ops for this object's attributes.

    Requires that the `Trackable` Python object has been bound to an object
    ID in the checkpoint.

    Returns:
      A list of operations when graph building, or an empty list when executing
      eagerly.
    """
    (restore_ops, tensor_saveables,
     python_saveables) = self.gather_ops_or_named_saveables()
    restore_ops.extend(
        self._checkpoint.restore_saveables(tensor_saveables, python_saveables))
    return restore_ops

  @property
  def checkpoint(self):
    return self._checkpoint

  @property
  def trackable(self):
    # Indexes `object_by_proto_id` directly, so this presumably raises
    # KeyError if no object has been bound to this proto id yet.
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def restore_uid(self):
    return self._checkpoint.restore_uid

  def __repr__(self):
    return repr(self.object_proto)

  def value_shape(self):
    """The shape of the VARIABLE_VALUE tensor.

    Returns:
      The shape recorded in the checkpoint's shape map for this object's
      VARIABLE_VALUE attribute (a TensorShape), or None if the object has no
      such attribute.
    """
    for serialized_tensor in self.object_proto.attributes:
      if serialized_tensor.name == VARIABLE_VALUE_KEY:
        return self._checkpoint.shape_map[serialized_tensor.checkpoint_key]
    return None


# A slot restoration queued until its optimizer is bound (see
# CheckpointPosition.bind_object).
_DeferredSlotVariableRestoration = collections.namedtuple(
    "_DeferredSlotVariableRestoration", [
        # The Python variable the slot belongs to.
        "original_variable",
        # The checkpoint proto id of the slot variable.
        "slot_variable_id",
        "slot_name",
    ])

_SlotVariableRestoration = collections.namedtuple(
    "_SlotVariableRestoration",
    [
        # The checkpoint proto id of the optimizer object.
        "optimizer_id",
        # The checkpoint proto id of the slot variable.
        "slot_variable_id",
        "slot_name",
    ])


@tf_export("__internal__.tracking.no_automatic_dependency_tracking", v1=[])
def no_automatic_dependency_tracking(method):
  """Disables automatic dependency tracking on attribute assignment.

  Use to decorate any method of a Trackable object. Attribute assignment in
  that method will not add dependencies (also respected in Model). Harmless if
  used in a class which does not do automatic dependency tracking (which means
  it's safe to use in base classes which may have subclasses which also inherit
  from Trackable).

  Args:
    method: The method to decorate.

  Returns:
    A decorated method which sets and un-sets automatic dependency tracking for
    the object the method is called on (not thread safe).
  """

  def _method_wrapper(self, *args, **kwargs):
    previous_value = getattr(self, "_self_setattr_tracking", True)
    self._self_setattr_tracking = False  # pylint: disable=protected-access
    try:
      result = method(self, *args, **kwargs)
    finally:
      # Restore the previous setting even if `method` raises.
      self._self_setattr_tracking = previous_value  # pylint: disable=protected-access
    return result

  return tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)


@tf_contextlib.contextmanager
def no_manual_dependency_tracking_scope(obj):
  """A context that disables manual dependency tracking for the given `obj`.

  Sometimes library methods might track objects on their own and we might want
  to disable that and do the tracking on our own. One can then use this context
  manager to disable the tracking the library method does and do your own
  tracking.

  For example:

  class TestLayer(tf.keras.layers.Layer):
    def build(self):
      with no_manual_dependency_tracking_scope(self):
        var = self.add_variable("name1")  # Creates a var and doesn't track it
      self._track_trackable("name2", var)  # We track variable with name `name2`

  Args:
    obj: A trackable object.

  Yields:
    a scope in which the object doesn't track dependencies manually.
  """
  # pylint: disable=protected-access
  # `_manual_tracking` is read by Trackable._track_trackable, which returns
  # early without recording the dependency when it is False.
  previous_value = getattr(obj, "_manual_tracking", True)
  obj._manual_tracking = False
  try:
    yield
  finally:
    obj._manual_tracking = previous_value


@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
  """A context that disables automatic dependency tracking when assigning attrs.

  Objects that inherit from `AutoTrackable` automatically create dependencies
  on trackable objects through attribute assignments, and wrap data structures
  (lists or dicts) with trackable classes. This scope may be used to temporarily
  disable this behavior. This works similarly to the decorator
  `no_automatic_dependency_tracking`.

  Example usage:
  ```
  model = tf.keras.Model()
  model.arr1 = []  # Creates a ListWrapper object
  with no_automatic_dependency_tracking_scope(model):
    model.arr2 = []  # Creates a regular, untracked python list
  ```

  Args:
    obj: A trackable object.

  Yields:
    a scope in which the object doesn't track dependencies.
  """
  previous_value = getattr(obj, "_setattr_tracking", True)
  obj._setattr_tracking = False  # pylint: disable=protected-access
  try:
    yield
  finally:
    obj._setattr_tracking = previous_value  # pylint: disable=protected-access


@tf_export("__internal__.tracking.Trackable", v1=[])
class Trackable(object):
  """Base class for `Trackable` objects without automatic dependencies.

  This class has no __setattr__ override for performance reasons. Dependencies
  must be added explicitly. Unless attribute assignment is performance-critical,
  use `AutoTrackable` instead. Use `Trackable` for `isinstance`
  checks.
  """

  # For compatibility with wrapt.ObjectProxy, attributes are all prefixed with
  # _self_. We have some properties to forward semi-public attributes to their
  # _self_ equivalents.

  @property
  def _setattr_tracking(self):
    # Lazily defaults to True for objects that never set it.
    if not hasattr(self, "_self_setattr_tracking"):
      self._self_setattr_tracking = True
    return self._self_setattr_tracking

  @_setattr_tracking.setter
  def _setattr_tracking(self, value):
    self._self_setattr_tracking = value

  @property
  def _update_uid(self):
    return self._self_update_uid

  @_update_uid.setter
  def _update_uid(self, value):
    self._self_update_uid = value

  @property
  def _unconditional_checkpoint_dependencies(self):
    return self._self_unconditional_checkpoint_dependencies

  @property
  def _unconditional_dependency_names(self):
    return self._self_unconditional_dependency_names

  @property
  def _name_based_restores(self):
    return self._self_name_based_restores

  # Trackable does not do automatic dependency tracking, but uses the
  # no_automatic_dependency_tracking decorator so it can avoid adding
  # dependencies if a subclass is AutoTrackable / inherits from Model (both of
  # which have __setattr__ overrides).
  @no_automatic_dependency_tracking
  def _maybe_initialize_trackable(self):
    """Initialize dependency management.

    Not __init__, since most objects will forget to call it.
    """
    if hasattr(self, "_self_unconditional_checkpoint_dependencies"):
      # __init__ already called. This check means that we don't need
      # Trackable.__init__() in the constructor of every TensorFlow object.
      return
    # A list of TrackableReference objects. Some classes implementing
    # `Trackable`, notably `Optimizer`s, may override the
    # _checkpoint_dependencies property with conditional dependencies
    # (e.g. based on the current graph when saving).
    self._self_unconditional_checkpoint_dependencies = []
    # Maps names -> Trackable objects
    self._self_unconditional_dependency_names = {}
    # Restorations for other Trackable objects on which this object may
    # eventually depend. Maps local name -> CheckpointPosition list. Optimizers
    # tack on conditional dependencies, and so need separate management of
    # deferred dependencies too.
    self._self_unconditional_deferred_dependencies = {}
    # The UID of the highest assignment to this object. Used to ensure that the
    # last requested assignment determines the final value of an object.
    if hasattr(self, "_self_update_uid"):
      raise AssertionError(
          "Internal error: the object had an update UID set before its "
          "initialization code was run.")
    self._self_update_uid = -1
    # When executing eagerly, holds a collection of _NameBasedRestoreCoordinator
    # instances, which should be checked when creating variables or other
    # saveables. These are passed on recursively to all dependencies, since
    # unlike object-based checkpoint restores we don't know which subgraph is
    # being restored in advance. This mechanism is only necessary for
    # restore-on-create when executing eagerly, and so is unused when graph
    # building.
    self._self_name_based_restores = set()

    # Dictionary of SaveableObjects factories. This dictionary is defined when
    # the object is loaded from the SavedModel. When writing a custom class,
    # prefer overriding "_gather_saveables_for_checkpoint" to using this
    # attribute.
    self._self_saveable_object_factories = {}

  @property
  def _object_identifier(self):
    """String used to identify this object in a SavedModel.

    Generally, the object identifier is constant across objects of the same
    class, while the metadata field is used for instance-specific data.

    Returns:
      String object identifier.
    """
    return "_generic_user_object"

  def _no_dependency(self, value):
    """If automatic dependency tracking is enabled, ignores `value`.

    This base implementation simply returns `value` unchanged.
    """
    return value

  def _name_based_attribute_restore(self, checkpoint):
    """Restore the object's attributes from a name-based checkpoint.

    The restore runs only if `checkpoint` is newer (by restore UID) than the
    last assignment to this object; the checkpoint is remembered either way.
    """
    self._self_name_based_restores.add(checkpoint)
    if self._self_update_uid < checkpoint.restore_uid:
      checkpoint.eager_restore(self)
      self._self_update_uid = checkpoint.restore_uid

  @property
  def _checkpoint_dependencies(self):
    """All dependencies of this object.

    May be overridden to include conditional dependencies.

    Returns:
      A list of `TrackableReference` objects indicating named
      `Trackable` dependencies which should be saved along with this
      object.
    """
    return self._self_unconditional_checkpoint_dependencies

  @property
  def _deferred_dependencies(self):
    """A dictionary with deferred dependencies.

    Stores restorations for other Trackable objects on which this object
    may eventually depend. May be overridden by sub-classes (e.g. Optimizers use
    conditional dependencies based on the current graph, and so need separate
    management of deferred dependencies too).

    Returns:
      A dictionary mapping from local name to a list of CheckpointPosition
      objects.
    """
    return self._self_unconditional_deferred_dependencies

  def _lookup_dependency(self, name):
    """Look up a dependency by name.

    May be overridden to include conditional dependencies.

    Args:
      name: The local name of the dependency.

    Returns:
      A `Trackable` object, or `None` if no dependency by this name was
      found.
    """
    return self._self_unconditional_dependency_names.get(name, None)

  def _add_variable_with_custom_getter(self,
                                       name,
                                       shape=None,
                                       dtype=dtypes.float32,
                                       initializer=None,
                                       getter=None,
                                       overwrite=False,
                                       **kwargs_for_getter):
    """Restore-on-create for a variable to be saved with this `Trackable`.

    If the user has requested that this object or another `Trackable` which
    depends on this object be restored from a checkpoint (deferred loading
    before variable object creation), `initializer` may be ignored and the value
    from the checkpoint used instead.

    Args:
      name: A name for the variable. Must be unique within this object.
      shape: The shape of the variable.
      dtype: The data type of the variable.
      initializer: The initializer to use. Ignored if there is a deferred
        restoration left over from a call to
        `_restore_from_checkpoint_position`.
      getter: The getter to wrap which actually fetches the variable.
      overwrite: If True, disables unique name and type checks.
      **kwargs_for_getter: Passed to the getter.

    Returns:
      The new variable object.

    Raises:
      ValueError: If the variable name is not unique (raised by
        `_track_trackable` when `overwrite` is False).
    """
    self._maybe_initialize_trackable()
    with ops.init_scope():
      if context.executing_eagerly():
        # If this is a variable with a single Tensor stored in the checkpoint,
        # we can set that value as an initializer rather than initializing and
        # then assigning (when executing eagerly). This call returns None if
        # there is nothing to restore.
        checkpoint_initializer = self._preload_simple_restoration(
            name=name)
      else:
        checkpoint_initializer = None
      if (checkpoint_initializer is not None and
          not (isinstance(initializer, CheckpointInitialValueCallable) and
               (initializer.restore_uid > checkpoint_initializer.restore_uid))):
        # If multiple Trackable objects are "creating" the same variable
        # via the magic of custom getters, the one with the highest restore UID
        # (the one called last) has to make the final initializer. If another
        # custom getter interrupts this process by overwriting the initializer,
        # then we'll catch that when we call _track_trackable. So this is
        # "best effort" to set the initializer with the highest restore UID.
        initializer = checkpoint_initializer
    new_variable = getter(
        name=name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        **kwargs_for_getter)

    # If we set an initializer and the variable processed it, tracking will not
    # assign again. It will add this variable to our dependencies, and if there
    # is a non-trivial restoration queued, it will handle that. This also
    # handles slot variables.
    if not overwrite or isinstance(new_variable, Trackable):
      return self._track_trackable(new_variable, name=name, overwrite=overwrite)
    else:
      # TODO(allenl): Some variable types are not yet supported. Remove this
      # fallback once all get_variable() return types are Trackable.
      return new_variable

  def _preload_simple_restoration(self, name):
    """Return a dependency's value for restore-on-create.

    Note the restoration is not deleted; if for some reason preload is called
    and then not assigned to the variable (for example because a custom getter
    overrides the initializer), the assignment will still happen once the
    variable is tracked (determined based on checkpoint.restore_uid).

    Args:
      name: The object-local name of the dependency holding the variable's
        value.

    Returns:
      A callable (`CheckpointInitialValueCallable`) for use as a variable's
      initializer/initial_value, or None if one should not be set (either
      because there was no variable with this name in the checkpoint or
      because it needs more complex deserialization). Any non-trivial
      deserialization will happen when the variable object is tracked.
    """
    deferred_dependencies_list = self._deferred_dependencies.get(name, ())
    if not deferred_dependencies_list:
      # Nothing to do; we don't have a restore for this dependency queued up.
      return
    for checkpoint_position in deferred_dependencies_list:
      if not checkpoint_position.is_simple_variable():
        # If _any_ pending restoration is too complicated to fit in an
        # initializer (because it has dependencies, or because there are
        # multiple Tensors to restore), bail and let the general tracking code
        # handle it.
        return None
    # The most recently requested restore (highest UID) wins.
    checkpoint_position = max(
        deferred_dependencies_list,
        key=lambda restore: restore.checkpoint.restore_uid)
    return CheckpointInitialValueCallable(
        checkpoint_position=checkpoint_position)

  def _track_trackable(self, trackable, name, overwrite=False):
    """Declare a dependency on another `Trackable` object.

    Indicates that checkpoints for this object should include variables from
    `trackable`.

    Variables in a checkpoint are mapped to `Trackable`s based on the names
    provided when the checkpoint was written. To avoid breaking existing
    checkpoints when modifying a class, neither variable names nor dependency
    names (the names passed to `_track_trackable`) may change.

    Args:
      trackable: A `Trackable` which this object depends on.
      name: A local name for `trackable`, used for loading checkpoints into the
        correct objects.
      overwrite: Boolean, whether silently replacing dependencies is OK. Used
        for __setattr__, where throwing an error on attribute reassignment would
        be inappropriate.

    Returns:
      `trackable`, for convenience when declaring a dependency and
      assigning to a member variable in one statement.

    Raises:
      TypeError: If `trackable` does not inherit from `Trackable`.
      ValueError: If another object is already tracked by this name.
    """
    self._maybe_initialize_trackable()
    if not isinstance(trackable, Trackable):
      raise TypeError(("Trackable._track_trackable() passed type %s, not a "
                       "Trackable.") % (type(trackable),))
    # Inside no_manual_dependency_tracking_scope: record nothing.
    if not getattr(self, "_manual_tracking", True):
      return trackable
    new_reference = TrackableReference(name=name, ref=trackable)
    current_object = self._lookup_dependency(name)
    if (current_object is not None and current_object is not trackable):
      if not overwrite:
        raise ValueError(
            ("Called Trackable._track_trackable() with name='%s', "
             "but a Trackable with this name is already declared as a "
             "dependency. Names must be unique (or overwrite=True).") % (name,))
      # This is a weird thing to do, but we're not going to stop people from
      # using __setattr__.
      for index, (old_name, _) in enumerate(
          self._self_unconditional_checkpoint_dependencies):
        if name == old_name:
          self._self_unconditional_checkpoint_dependencies[
              index] = new_reference
    elif current_object is None:
      # First time this name is tracked: record it and apply any restores that
      # were queued for it.
      self._self_unconditional_checkpoint_dependencies.append(new_reference)
      self._handle_deferred_dependencies(name=name, trackable=trackable)
    self._self_unconditional_dependency_names[name] = trackable
    return trackable

  def _handle_deferred_dependencies(self, name, trackable):
    """Pop and load any deferred checkpoint restores into `trackable`.
923 924 This method does not add a new dependency on `trackable`, but it does 925 check if any outstanding/deferred dependencies have been queued waiting for 926 this dependency to be added (matched based on `name`). If so, 927 `trackable` and its dependencies are restored. The restorations are 928 considered fulfilled and so are deleted. 929 930 `_track_trackable` is more appropriate for adding a 931 normal/unconditional dependency, and includes handling for deferred 932 restorations. This method allows objects such as `Optimizer` to use the same 933 restoration logic while managing conditional dependencies themselves, by 934 overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the 935 object's dependencies based on the context it is saved/restored in (a single 936 optimizer instance can have state associated with multiple graphs). 937 938 Args: 939 name: The name of the dependency within this object (`self`), used to 940 match `trackable` with values saved in a checkpoint. 941 trackable: The Trackable object to restore (inheriting from `Trackable`). 942 """ 943 self._maybe_initialize_trackable() 944 trackable._maybe_initialize_trackable() # pylint: disable=protected-access 945 deferred_dependencies_list = self._deferred_dependencies.pop(name, ()) 946 for checkpoint_position in sorted( 947 deferred_dependencies_list, 948 key=lambda restore: restore.checkpoint.restore_uid, 949 reverse=True): 950 checkpoint_position.restore(trackable) 951 952 # Pass on any name-based restores queued in this object. 
953 for name_based_restore in sorted( 954 self._self_name_based_restores, 955 key=lambda checkpoint: checkpoint.restore_uid, 956 reverse=True): 957 trackable._name_based_attribute_restore(name_based_restore) # pylint: disable=protected-access 958 959 def _restore_from_checkpoint_position(self, checkpoint_position): 960 """Restore this object and its dependencies (may be deferred).""" 961 # Attempt a breadth-first traversal, since presumably the user has more 962 # control over shorter paths. If we don't have all of the dependencies at 963 # this point, the end result is not breadth-first (since other deferred 964 # traversals will happen later). 965 visit_queue = collections.deque([checkpoint_position]) 966 restore_ops = [] 967 tensor_saveables = {} 968 python_saveables = [] 969 while visit_queue: 970 current_position = visit_queue.popleft() 971 new_restore_ops, new_tensor_saveables, new_python_saveables = ( 972 current_position.trackable # pylint: disable=protected-access 973 ._single_restoration_from_checkpoint_position( 974 checkpoint_position=current_position, 975 visit_queue=visit_queue)) 976 restore_ops.extend(new_restore_ops) 977 tensor_saveables.update(new_tensor_saveables) 978 python_saveables.extend(new_python_saveables) 979 restore_ops.extend( 980 current_position.checkpoint.restore_saveables( 981 tensor_saveables, python_saveables)) 982 return restore_ops 983 984 def _single_restoration_from_checkpoint_position(self, checkpoint_position, 985 visit_queue): 986 """Restore this object, and either queue its dependencies or defer them.""" 987 self._maybe_initialize_trackable() 988 checkpoint = checkpoint_position.checkpoint 989 # If the UID of this restore is lower than our current update UID, we don't 990 # need to actually restore the object. However, we should pass the 991 # restoration on to our dependencies. 
992 if checkpoint.restore_uid > self._self_update_uid: 993 restore_ops, tensor_saveables, python_saveables = ( 994 checkpoint_position.gather_ops_or_named_saveables()) 995 self._self_update_uid = checkpoint.restore_uid 996 else: 997 restore_ops = () 998 tensor_saveables = {} 999 python_saveables = () 1000 for child in checkpoint_position.object_proto.children: 1001 child_position = CheckpointPosition( 1002 checkpoint=checkpoint, proto_id=child.node_id) 1003 local_object = self._lookup_dependency(child.local_name) 1004 if local_object is None: 1005 # We don't yet have a dependency registered with this name. Save it 1006 # in case we do. 1007 self._deferred_dependencies.setdefault(child.local_name, 1008 []).append(child_position) 1009 else: 1010 if child_position.bind_object(trackable=local_object): 1011 # This object's correspondence is new, so dependencies need to be 1012 # visited. Delay doing it so that we get a breadth-first dependency 1013 # resolution order (shallowest paths first). The caller is responsible 1014 # for emptying visit_queue. 1015 visit_queue.append(child_position) 1016 return restore_ops, tensor_saveables, python_saveables 1017 1018 def _gather_saveables_for_checkpoint(self): 1019 """Returns a dictionary of values to checkpoint with this object. 1020 1021 Keys in the returned dictionary are local to this object and in a separate 1022 namespace from dependencies. Values may either be `SaveableObject` factories 1023 or variables easily converted to `SaveableObject`s (as in 1024 `tf.compat.v1.train.Saver`'s 1025 `var_list` constructor argument). 1026 1027 `SaveableObjects` have a name set, which Trackable needs to generate 1028 itself. So rather than returning `SaveableObjects` directly, this method 1029 should return a dictionary of callables which take `name` arguments and 1030 return `SaveableObjects` with that name. 
1031 1032 If this object may also be passed to the global-name-based 1033 `tf.compat.v1.train.Saver`, 1034 the returned callables should have a default value for their name argument 1035 (i.e. be callable with no arguments). 1036 1037 Returned values must be saved only by this object; if any value may be 1038 shared, it should instead be a dependency. For example, variable objects 1039 save their own values with the key `VARIABLE_VALUE_KEY`, but objects which 1040 reference variables simply add a dependency. 1041 1042 Returns: 1043 The dictionary mapping attribute names to `SaveableObject` factories 1044 described above. For example: 1045 {VARIABLE_VALUE_KEY: 1046 lambda name="global_name_for_this_object": 1047 SaveableObject(name=name, ...)} 1048 """ 1049 return self._self_saveable_object_factories 1050 1051 def _list_extra_dependencies_for_serialization(self, serialization_cache): 1052 """Lists extra dependencies to serialize. 1053 1054 Internal sub-classes can override this method to return extra dependencies 1055 that should be saved with the object during SavedModel serialization. For 1056 example, this is used to save `trainable_variables` in Keras models. The 1057 python property `trainable_variables` contains logic to iterate through the 1058 weights from the model and its sublayers. The serialized Keras model saves 1059 `trainable_weights` as a trackable list of variables. 1060 1061 PLEASE NOTE when overriding this method: 1062 1. This function may only generate new trackable objects the first time it 1063 is called. 1064 2. The returned dictionary must not have naming conflicts with 1065 dependencies tracked by the root. In other words, if the root is 1066 tracking `object_1` with name 'x', and this functions returns 1067 `{'x': object_2}`, an error is raised when saving. 1068 1069 Args: 1070 serialization_cache: A dictionary shared between all objects in the same 1071 object graph. 
This object is passed to both 1072 `_list_extra_dependencies_for_serialization` and 1073 `_list_functions_for_serialization`. 1074 1075 Returns: 1076 A dictionary mapping attribute names to trackable objects. 1077 """ 1078 del serialization_cache 1079 return dict() 1080 1081 def _list_functions_for_serialization(self, serialization_cache): 1082 """Lists the functions of this trackable to serialize. 1083 1084 Internal sub-classes can override this with specific logic. E.g. 1085 `AutoTrackable` provides an implementation that returns the `attr` 1086 that return functions. 1087 1088 Args: 1089 serialization_cache: Dictionary passed to all objects in the same object 1090 graph during serialization. 1091 1092 Returns: 1093 A dictionary mapping attribute names to `Function` or 1094 `ConcreteFunction`. 1095 """ 1096 del serialization_cache 1097 return dict() 1098 1099 def _map_resources(self, save_options): # pylint: disable=unused-argument 1100 """Makes new resource handle ops corresponding to existing resource tensors. 1101 1102 Internal sub-classes can override this to inform model saving how to add new 1103 resource handle ops to the main GraphDef of a SavedModel (TF 1.x style 1104 graph), which allows session based APIs (e.g, C++ loader API) to interact 1105 with resources owned by this object. 1106 1107 Args: 1108 save_options: A tf.saved_model.SaveOptions instance. 1109 1110 Returns: 1111 A tuple of (object_map, resource_map): 1112 object_map: A dictionary mapping from objects that hold existing 1113 resource tensors to replacement objects created to hold the new 1114 resource tensors. 1115 resource_map: A dictionary mapping from existing resource tensors to 1116 newly created resource tensors. 1117 """ 1118 return {}, {} 1119