"""An object-local variable management scheme."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections

import six

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export

# Key where the object graph proto is saved in a TensorBundle.
# NOTE(review): the value keeps the legacy "CHECKPOINTABLE" spelling although
# the classes in this module are named Trackable; presumably it must stay
# stable so that previously written checkpoints remain readable -- confirm
# before ever renaming it.
OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"

# A key indicating a variable's value in an object's checkpointed Tensors
# (Trackable._gather_saveables_for_checkpoint). If this is the only key and
# the object has no dependencies, then its value may be restored on object
# creation (avoiding double assignment when executing eagerly).
VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"

TrackableReference = collections.namedtuple(
    "TrackableReference",
    [
        # The local name for this dependency.
        "name",
        # The Trackable object being referenced.
        "ref"
    ])


# TODO(bfontain): Update once sharded initialization interface is finalized.
# `shape` is the shape of the shard and `offset` its per-dimension starting
# index within the full variable (see CheckpointInitialValue.__init__).
ShardInfo = collections.namedtuple(
    "CheckpointInitialValueShardInfo", ["shape", "offset"])


class CheckpointInitialValueCallable(object):
  """A callable object that returns a CheckpointInitialValue.

  See CheckpointInitialValue for more information.
  """

  def __init__(self, checkpoint_position):
    self._checkpoint_position = checkpoint_position

  @property
  def checkpoint_position(self):
    """The `CheckpointPosition` this initializer reads its value from."""
    return self._checkpoint_position

  def __call__(self, shape=None, dtype=None, shard_info=None):
    # Note that the signature here is for compatibility with normal callable
    # initializers which take shape and dtype. Although dtype isn't used, it
    # will get passed in by a functools.partial wrapper in places like
    # base_layer_utils.py's make_variable.
    return CheckpointInitialValue(
        self._checkpoint_position, shape, shard_info=shard_info)

  @property
  def restore_uid(self):
    """The restore UID of the checkpoint load this initializer comes from."""
    return self._checkpoint_position.restore_uid


class CheckpointInitialValue(ops.Tensor):
  """Tensor wrapper for managing update UIDs in `Variables`.

  When supplied as an initial value, objects of this type let a `Variable`
  (`Variable`, `ResourceVariable`, etc.) know the UID of the restore the initial
  value came from. This allows deferred restorations to be sequenced in the
  order the user specified them, and lets us fall back on assignment if an
  initial value is not set (e.g. due to a custom getter interfering).

  See comments in _add_variable_with_custom_getter for more information about
  how `CheckpointInitialValue` is used.
  """

  def __init__(self, checkpoint_position, shape=None, shard_info=None):
    """Reads the checkpointed value for `checkpoint_position`.

    Args:
      checkpoint_position: A `CheckpointPosition` whose `VARIABLE_VALUE_KEY`
        attribute is restored and wrapped.
      shape: Optional static shape to set on the restored tensor. When
        `shard_info` is given this must be the full (unsharded) shape, since it
        is used to build the RestoreV2 `shape_and_slices` string.
      shard_info: Optional `ShardInfo`. When set, only the slice described by
        its `offset`/`shape` is restored and the static shape is set to
        `shard_info.shape`.
    """
    if shard_info:
      # RestoreV2 slice syntax: "<full dims separated by spaces> <off,len:...>".
      full_shape_str = " ".join("%d" % d for d in shape) + " "
      slice_spec = ":".join(
          "%d,%d" % (o, s) for o, s in zip(shard_info.offset, shard_info.shape))
      shape_and_slice = full_shape_str + slice_spec
      # Override shape here so we set the correct shape below.
      shape = shard_info.shape
    else:
      shape_and_slice = ""
    self.wrapped_value = checkpoint_position.value_tensors(
        {VARIABLE_VALUE_KEY: shape_and_slice})[VARIABLE_VALUE_KEY]
    if shape:
      # We need to set the static shape information on the initializer if
      # possible so we don't get a variable with an unknown shape.
      self.wrapped_value.set_shape(shape)
    self._checkpoint_position = checkpoint_position

  def __getattr__(self, attr):
    """Forwards failed attribute lookups to the wrapped Tensor."""
    try:
      # Fetch `wrapped_value` without going through normal attribute lookup:
      # on an instance where it was never set (e.g. one created via __new__ by
      # copy/pickle machinery) a plain `self.wrapped_value` would re-enter
      # __getattr__ and recurse until RecursionError.
      wrapped_value = object.__getattribute__(self, "wrapped_value")
      return getattr(wrapped_value, attr)
    except AttributeError:
      # Re-raise with this object's own AttributeError.
      return self.__getattribute__(attr)

  @property
  def checkpoint_position(self):
    """The `CheckpointPosition` the wrapped value was restored from."""
    return self._checkpoint_position


class NoRestoreSaveable(saveable_object.SaveableObject):
  """Embeds a tensor in a checkpoint with no restore ops."""

  def __init__(self, tensor, name, dtype=None, device=None):
    """Wraps `tensor` in a single `SaveSpec` saved under `name`.

    Args:
      tensor: The value to save; forwarded to `SaveSpec` (`freeze` below
        passes a callable producing the tensor).
      name: The checkpoint key.
      dtype: Optional dtype forwarded to `SaveSpec`.
      device: Optional device forwarded to `SaveSpec`.
    """
    spec = saveable_object.SaveSpec(
        tensor, "", name, dtype=dtype, device=device)
    super(NoRestoreSaveable, self).__init__(tensor, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Restoring is a no-op for this saveable."""
    return control_flow_ops.no_op()


@six.add_metaclass(abc.ABCMeta)
class PythonStateSaveable(saveable_object.SaveableObject):
  """An interface for saving/restoring volatile Python state."""

  @abc.abstractmethod
  def feed_dict_additions(self):
    """When running a graph, indicates fresh state to feed.

    Returns:
      A dictionary mapping `Tensor`s to current Python state.
    """
    pass

  @abc.abstractmethod
  def freeze(self):
    """Create a new `SaveableObject` which freezes current state as a constant.

    Used when executing eagerly to embed the current state as a constant, or
    when creating a static tf.compat.v1.train.Saver with the frozen current
    Python state.

    Returns:
      A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
      no Python state associated with it).
    """
    pass


class PythonStringStateSaveable(PythonStateSaveable):
  """Saves Python state in a checkpoint."""

  def __init__(self, name, state_callback, restore_callback=None):
    """Configure saving.

    Args:
      name: The checkpoint key to write to.
      state_callback: A function taking no arguments which returns a string.
        This function is run every time a checkpoint is written.
      restore_callback: A function taking a Python string, used to restore
        state. Optional; defaults to doing nothing, in which case it is ignored
        by status assertions such as assert_consumed().
    """
    # Despite the name, this records whether a *restore* callback is absent;
    # it backs `optional_restore` below.
    self._has_trivial_state_callback = (restore_callback is None)

    def _state_callback_wrapper():
      # Lift out of any function-building graph so the callback runs eagerly
      # (or in the outermost graph).
      with ops.init_scope():
        return state_callback()

    self._state_callback = _state_callback_wrapper
    self._restore_callback = restore_callback
    with ops.device("/cpu:0"):
      # Placeholder-like string tensor; the real value is fed via
      # feed_dict_additions() when graph building.
      self._save_string = constant_op.constant("", dtype=dtypes.string)
    spec = saveable_object.SaveSpec(
        self._save_string, "", name, dtype=dtypes.string)
    super(PythonStringStateSaveable, self).__init__(self._save_string, [spec],
                                                    name)

  @property
  def optional_restore(self):
    """For values with no restore, relaxes assert_consumed()."""
    return self._has_trivial_state_callback

  def feed_dict_additions(self):
    """When running a graph, indicates fresh state to feed."""
    return {self._save_string: self._state_callback()}

  def freeze(self):
    """Create a frozen `SaveableObject` which saves the current state."""

    def _constant_state():
      return constant_op.constant(self._state_callback(), dtype=dtypes.string)

    return NoRestoreSaveable(
        tensor=_constant_state,
        dtype=dtypes.string,
        name=self.name,
        device="cpu:0")

  def python_restore(self, restored_strings):
    """Called to restore Python state.

    Args:
      restored_strings: A one-element sequence holding the restored string.
        Ignored when no restore callback was configured.
    """
    if self._restore_callback:
      restored, = restored_strings
      self._restore_callback(restored)

  def restore(self, restored_tensors, restored_shapes):
    """Called to restore TensorFlow state (nothing to do)."""
    return control_flow_ops.no_op()


class CheckpointPosition(object):
  """Indicates a position within a `_CheckpointRestoreCoordinator`."""

  __slots__ = ["_checkpoint", "_proto_id"]

  def __init__(self, checkpoint, proto_id):
    """Specify an object within a checkpoint.

    Args:
      checkpoint: A _CheckpointRestoreCoordinator object.
      proto_id: The index of this object in TrackableObjectGraph.nodes.
    """
    self._checkpoint = checkpoint
    self._proto_id = proto_id

  def restore(self, trackable):
    """Restore this value into `trackable`.

    Binds `trackable` to this position; if the binding is new, asks
    `trackable` for its restore ops and registers any returned ops with the
    checkpoint coordinator.

    Args:
      trackable: The Python object to restore into.
    """
    with ops.init_scope():
      if self.bind_object(trackable):
        # This object's correspondence with a checkpointed object is new, so
        # process deferred restorations for it and its dependencies.
        restore_ops = trackable._restore_from_checkpoint_position(self)  # pylint: disable=protected-access
        if restore_ops:
          self._checkpoint.new_restore_ops(restore_ops)

  def bind_object(self, trackable):
    """Set a checkpoint<->object correspondence and process slot variables.

    On a new binding this also creates/restores slot variables whose
    restoration was waiting for `trackable` (either as the optimizer or as the
    variable the slot belongs to), or defers them until the optimizer is bound.

    Args:
      trackable: The object to record a correspondence for.

    Returns:
      True if this is a new assignment, False if an object has already been
      mapped to this checkpointed `Object` proto. If that existing object is a
      *different* object, a warning is logged (no exception is raised) and
      False is returned.
    """
    checkpoint = self.checkpoint
    checkpoint.all_python_objects.add(trackable)
    current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
    checkpoint.matched_proto_ids.add(self._proto_id)
    if current_assignment is None:
      checkpoint.object_by_proto_id[self._proto_id] = trackable
      # `trackable` is an optimizer some variables were waiting on.
      for deferred_slot_restoration in (
          checkpoint.deferred_slot_restorations.pop(self._proto_id, ())):
        trackable._create_or_restore_slot_variable(  # pylint: disable=protected-access
            slot_variable_position=CheckpointPosition(
                checkpoint=checkpoint,
                proto_id=deferred_slot_restoration.slot_variable_id),
            variable=deferred_slot_restoration.original_variable,
            slot_name=deferred_slot_restoration.slot_name)
      # `trackable` is a variable that has slot variables in the checkpoint.
      for slot_restoration in checkpoint.slot_restorations.pop(
          self._proto_id, ()):
        optimizer_object = checkpoint.object_by_proto_id.get(
            slot_restoration.optimizer_id, None)
        if optimizer_object is None:
          # The optimizer has not yet been created or tracked. Record in the
          # checkpoint that the slot variables need to be restored when it is.
          checkpoint.deferred_slot_restorations.setdefault(
              slot_restoration.optimizer_id, []).append(
                  _DeferredSlotVariableRestoration(
                      original_variable=trackable,
                      slot_variable_id=slot_restoration.slot_variable_id,
                      slot_name=slot_restoration.slot_name))

        # `optimizer_object` can be a `Checkpoint` when user only needs the
        # attributes the optimizer holds, such as `iterations`. In those cases,
        # it would not have the optimizer's `_create_or_restore_slot_variable`
        # method.
        elif hasattr(optimizer_object, "_create_or_restore_slot_variable"):
          optimizer_object._create_or_restore_slot_variable(  # pylint: disable=protected-access
              slot_variable_position=CheckpointPosition(
                  checkpoint=checkpoint,
                  proto_id=slot_restoration.slot_variable_id),
              variable=trackable,
              slot_name=slot_restoration.slot_name)
      return True  # New assignment
    else:
      # The object was already mapped for this checkpoint load, which means
      # we don't need to do anything besides check that the mapping is
      # consistent (if the dependency DAG is not a tree then there are
      # multiple paths to the same object).
      if current_assignment is not trackable:
        logging.warning(
            ("Inconsistent references when loading the checkpoint into this "
             "object graph. Either the Trackable object references in the "
             "Python program have changed in an incompatible way, or the "
             "checkpoint was generated in an incompatible program.\n\nTwo "
             "checkpoint references resolved to different objects (%s and %s)."),
            current_assignment, trackable)
      return False  # Not a new assignment

  def is_simple_variable(self):
    """Determine whether this value is restorable with a Tensor initializer.

    True only when the checkpointed object has exactly one attribute, named
    `VARIABLE_VALUE_KEY`, and no children.
    """
    attributes = self.object_proto.attributes
    return (len(attributes) == 1 and
            attributes[0].name == VARIABLE_VALUE_KEY and
            not self.object_proto.children)

  def value_tensors(self, shape_and_slices=None):
    """Create value `Tensor`s for this object's attributes.

    Does not require that the Python object has been created. Used for
    restore-on-create when executing eagerly.

    Args:
      shape_and_slices: A dict mapping from object attribute names to a shape
        and slice string that will be passed to a RestoreV2 op. If the dict is
        None or if an object attribute is not in the dict, the full tensor will
        be restored.

    Returns:
      A dictionary mapping from object attribute names to `Tensor`s.
    """
    value_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
      with ops.init_scope():
        with ops.device(io_device):
          # Run the restore itself on the io_device(CPU or specified).
          if (shape_and_slices is not None and
              serialized_tensor.name in shape_and_slices):
            shape_and_slice = shape_and_slices[serialized_tensor.name]
          else:
            shape_and_slice = ""
          value, = io_ops.restore_v2(
              prefix=self._checkpoint.save_path_tensor,
              tensor_names=[checkpoint_key],
              shape_and_slices=[shape_and_slice],
              dtypes=[base_type],
              name="%s_checkpoint_read" % (serialized_tensor.name,))
        # Copy the value to the current device if necessary.
        value_tensors[serialized_tensor.name] = array_ops.identity(value)
    return value_tensors

  def gather_ops_or_named_saveables(self):
    """Looks up or creates SaveableObjects which don't have cached ops.

    Returns:
      A tuple `(existing_restore_ops, named_saveables, python_saveables)`:
      cached restore ops (graph mode only), a dict mapping checkpoint keys to
      non-Python `SaveableObject`s, and a list of `PythonStateSaveable`s.
    """
    saveables = self.trackable._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
    # Name saveables based on the name this object had when it was checkpointed.
    named_saveables = {}
    python_saveables = []
    existing_restore_ops = []
    for serialized_tensor in self.object_proto.attributes:
      if context.executing_eagerly():
        existing_op = None
      else:
        existing_op = self._checkpoint.restore_ops_by_name.get(
            serialized_tensor.checkpoint_key, None)
      if existing_op is not None:
        existing_restore_ops.append(existing_op)
        continue

      # Only if we don't have cached ops for this SaveableObject, we'll see if
      # the SaveableObject itself has been cached. If not, we'll make it, and
      # either way we'll extract new ops from it (or if it has Python state to
      # restore, we'll run that).
      saveables_cache = self._checkpoint.graph_view.saveables_cache
      if saveables_cache is None:
        # No SaveableObject caching when executing eagerly.
        saveable = None
      else:
        # If we've already created and cached a SaveableObject for this
        # attribute, we can re-use it to avoid re-creating some ops when graph
        # building.
        saveable_list = saveables_cache.get(self.trackable,
                                            {}).get(serialized_tensor.name,
                                                    (None,))
        if len(saveable_list) == 1:
          # Almost every attribute will have exactly one SaveableObject.
          saveable, = saveable_list
        else:
          # Don't use cached SaveableObjects for partitioned variables, which is
          # the only case where we'd have a list of SaveableObjects. Op caching
          # will catch them.
          saveable = None
        if saveable is not None:
          # The name of this attribute has changed, so we need to re-generate
          # the SaveableObject.
          if serialized_tensor.checkpoint_key not in saveable.name:
            saveable = None
            del saveables_cache[self.trackable]
      if saveable is None:
        # If there was no cached SaveableObject, we should check if the Python
        # object has the attribute.
        saveable_factory = saveables.get(serialized_tensor.name, None)
        if saveable_factory is None:
          # Purposefully does not throw an exception if attributes have been
          # added or deleted. Stores unused attributes so an exception can be
          # raised if the user decides to check that everything in the
          # checkpoint was loaded.
          if not serialized_tensor.optional_restore:
            self._checkpoint.unused_attributes.setdefault(
                self._proto_id, []).append(serialized_tensor.name)
          continue
        if callable(saveable_factory):
          saveable = saveable_factory(name=serialized_tensor.checkpoint_key)
        else:
          saveable = saveable_factory
        if saveables_cache is not None:
          saveables_cache.setdefault(self.trackable,
                                     {})[serialized_tensor.name] = [saveable]
      if isinstance(saveable, PythonStateSaveable):
        python_saveables.append(saveable)
      else:
        named_saveables[serialized_tensor.checkpoint_key] = saveable
    return existing_restore_ops, named_saveables, python_saveables

  def restore_ops(self):
    """Create or fetch restore ops for this object's attributes.

    Requires that the `Trackable` Python object has been bound to an object
    ID in the checkpoint.

    Returns:
      A list of operations when graph building, or an empty list when executing
      eagerly.
    """
    (restore_ops, tensor_saveables,
     python_saveables) = self.gather_ops_or_named_saveables()
    restore_ops.extend(
        self._checkpoint.restore_saveables(tensor_saveables, python_saveables))
    return restore_ops

  @property
  def checkpoint(self):
    """The `_CheckpointRestoreCoordinator` this position belongs to."""
    return self._checkpoint

  @property
  def trackable(self):
    """The Python object bound to this position; lookup fails if unbound."""
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    """The `TrackableObjectGraph` node proto at this position."""
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def restore_uid(self):
    """The UID of the checkpoint load this position belongs to."""
    return self._checkpoint.restore_uid

  def __repr__(self):
    return repr(self.object_proto)


_DeferredSlotVariableRestoration = collections.namedtuple(
    "_DeferredSlotVariableRestoration", [
        "original_variable",
        "slot_variable_id",
        "slot_name",
    ])

_SlotVariableRestoration = collections.namedtuple(
    "_SlotVariableRestoration",
    [
        # The checkpoint proto id of the optimizer object.
        "optimizer_id",
        # The checkpoint proto id of the slot variable.
        "slot_variable_id",
        "slot_name",
    ])


def no_automatic_dependency_tracking(method):
  """Disables automatic dependency tracking on attribute assignment.

  Use to decorate any method of a Trackable object. Attribute assignment in
  that method will not add dependencies (also respected in Model). Harmless if
  used in a class which does not do automatic dependency tracking (which means
  it's safe to use in base classes which may have subclasses which also inherit
  from Trackable).

  Args:
    method: The method to decorate.

  Returns:
    A decorated method which sets and un-sets automatic dependency tracking for
    the object the method is called on (not thread safe).
  """

  def _method_wrapper(self, *args, **kwargs):
    previous_value = getattr(self, "_self_setattr_tracking", True)
    self._self_setattr_tracking = False  # pylint: disable=protected-access
    try:
      result = method(self, *args, **kwargs)
    finally:
      # Restore the previous setting even if `method` raises.
      self._self_setattr_tracking = previous_value  # pylint: disable=protected-access
    return result

  return tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)


@tf_contextlib.contextmanager
def no_manual_dependency_tracking_scope(obj):
  """A context that disables manual dependency tracking for the given `obj`.

  Sometimes library methods might track objects on their own and we might want
  to disable that and do the tracking on our own. One can then use this context
  manager to disable the tracking the library method does and do your own
  tracking.

  For example:

  class TestLayer(tf.keras.Layer):
    def build(self):
      with no_manual_dependency_tracking_scope(self):
        var = self.add_variable("name1")  # Creates a var and doesn't track it
      self._track_trackable("name2", var)  # We track variable with name `name2`

  Args:
    obj: A trackable object.

  Yields:
    a scope in which the object doesn't track dependencies manually.
  """
  # pylint: disable=protected-access
  previous_value = getattr(obj, "_manual_tracking", True)
  obj._manual_tracking = False
  try:
    yield
  finally:
    obj._manual_tracking = previous_value


@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
  """A context that disables automatic dependency tracking when assigning attrs.

  Objects that inherit from AutoTrackable automatically create dependencies
  to trackable objects through attribute assignments, and wrap data structures
  (lists or dicts) with trackable classes. This scope may be used to temporarily
  disable this behavior. This works similarly to the decorator
  `no_automatic_dependency_tracking`.

  Example usage:
  ```
  model = tf.keras.Model()
  model.arr1 = []  # Creates a ListWrapper object
  with no_automatic_dependency_tracking_scope(model):
    model.arr2 = []  # Creates a regular, untracked python list
  ```

  Args:
    obj: A trackable object.

  Yields:
    a scope in which the object doesn't track dependencies.
  """
  previous_value = getattr(obj, "_setattr_tracking", True)
  obj._setattr_tracking = False  # pylint: disable=protected-access
  try:
    yield
  finally:
    obj._setattr_tracking = previous_value  # pylint: disable=protected-access


@tf_export("__internal__.tracking.Trackable", v1=[])
class Trackable(object):
  """Base class for `Trackable` objects without automatic dependencies.

  This class has no __setattr__ override for performance reasons. Dependencies
  must be added explicitly. Unless attribute assignment is performance-critical,
  use `AutoTrackable` instead. Use `Trackable` for `isinstance`
  checks.
  """

  # For compatibility with wrapt.ObjectProxy, attributes are all prefixed with
  # _self_. We have some properties to forward semi-public attributes to their
  # _self_ equivalents.
603 604 @property 605 def _setattr_tracking(self): 606 if not hasattr(self, "_self_setattr_tracking"): 607 self._self_setattr_tracking = True 608 return self._self_setattr_tracking 609 610 @_setattr_tracking.setter 611 def _setattr_tracking(self, value): 612 self._self_setattr_tracking = value 613 614 @property 615 def _update_uid(self): 616 return self._self_update_uid 617 618 @_update_uid.setter 619 def _update_uid(self, value): 620 self._self_update_uid = value 621 622 @property 623 def _unconditional_checkpoint_dependencies(self): 624 return self._self_unconditional_checkpoint_dependencies 625 626 @property 627 def _unconditional_dependency_names(self): 628 return self._self_unconditional_dependency_names 629 630 @property 631 def _name_based_restores(self): 632 return self._self_name_based_restores 633 634 # Trackable does not do automatic dependency tracking, but uses the 635 # no_automatic_dependency_tracking decorator so it can avoid adding 636 # dependencies if a subclass is Trackable / inherits from Model (both of 637 # which have __setattr__ overrides). 638 @no_automatic_dependency_tracking 639 def _maybe_initialize_trackable(self): 640 """Initialize dependency management. 641 642 Not __init__, since most objects will forget to call it. 643 """ 644 if hasattr(self, "_self_unconditional_checkpoint_dependencies"): 645 # __init__ already called. This check means that we don't need 646 # Trackable.__init__() in the constructor of every TensorFlow object. 647 return 648 # A list of TrackableReference objects. Some classes implementing 649 # `Trackable`, notably `Optimizer`s, may override the 650 # _checkpoint_dependencies property with conditional dependencies 651 # (e.g. based on the current graph when saving). 652 self._self_unconditional_checkpoint_dependencies = [] 653 # Maps names -> Trackable objects 654 self._self_unconditional_dependency_names = {} 655 # Restorations for other Trackable objects on which this object may 656 # eventually depend. 
Maps local name -> CheckpointPosition list. Optimizers 657 # tack on conditional dependencies, and so need separate management of 658 # deferred dependencies too. 659 self._self_unconditional_deferred_dependencies = {} 660 # The UID of the highest assignment to this object. Used to ensure that the 661 # last requested assignment determines the final value of an object. 662 if hasattr(self, "_self_update_uid"): 663 raise AssertionError( 664 "Internal error: the object had an update UID set before its " 665 "initialization code was run.") 666 self._self_update_uid = -1 667 # When executing eagerly, holds a collection of _NameBasedRestoreCoordinator 668 # instances, which should be checked when creating variables or other 669 # saveables. These are passed on recursively to all dependencies, since 670 # unlike object-based checkpoint restores we don't know which subgraph is 671 # being restored in advance. This mechanism is only necessary for 672 # restore-on-create when executing eagerly, and so is unused when graph 673 # building. 674 self._self_name_based_restores = set() 675 676 # Dictionary of SaveableObjects factories. This dictionary is defined when 677 # the object is loaded from the SavedModel. When writing a custom class, 678 # prefer overriding "_gather_saveables_from_checkpoint" to using this 679 # attribute. 680 self._self_saveable_object_factories = {} 681 682 @property 683 def _object_identifier(self): 684 """String used to identify this object in a SavedModel. 685 686 Generally, the object identifier is constant across objects of the same 687 class, while the metadata field is used for instance-specific data. 688 689 Returns: 690 String object identifier. 
691 """ 692 return "_generic_user_object" 693 694 @property 695 def _tracking_metadata(self): 696 """String containing object metadata, which is saved to the SavedModel.""" 697 return "" 698 699 def _no_dependency(self, value): 700 """If automatic dependency tracking is enabled, ignores `value`.""" 701 return value 702 703 def _name_based_attribute_restore(self, checkpoint): 704 """Restore the object's attributes from a name-based checkpoint.""" 705 self._self_name_based_restores.add(checkpoint) 706 if self._self_update_uid < checkpoint.restore_uid: 707 checkpoint.eager_restore(self) 708 self._self_update_uid = checkpoint.restore_uid 709 710 @property 711 def _checkpoint_dependencies(self): 712 """All dependencies of this object. 713 714 May be overridden to include conditional dependencies. 715 716 Returns: 717 A list of `TrackableReference` objects indicating named 718 `Trackable` dependencies which should be saved along with this 719 object. 720 """ 721 return self._self_unconditional_checkpoint_dependencies 722 723 @property 724 def _deferred_dependencies(self): 725 """A dictionary with deferred dependencies. 726 727 Stores restorations for other Trackable objects on which this object 728 may eventually depend. May be overridden by sub-classes (e.g. Optimizers use 729 conditional dependencies based the current graph, and so need separate 730 management of deferred dependencies too). 731 732 Returns: 733 A dictionary mapping from local name to a list of CheckpointPosition 734 objects. 735 """ 736 return self._self_unconditional_deferred_dependencies 737 738 def _lookup_dependency(self, name): 739 """Look up a dependency by name. 740 741 May be overridden to include conditional dependencies. 742 743 Args: 744 name: The local name of the dependency. 745 746 Returns: 747 A `Trackable` object, or `None` if no dependency by this name was 748 found. 
749 """ 750 return self._self_unconditional_dependency_names.get(name, None) 751 752 def _add_variable_with_custom_getter(self, 753 name, 754 shape=None, 755 dtype=dtypes.float32, 756 initializer=None, 757 getter=None, 758 overwrite=False, 759 **kwargs_for_getter): 760 """Restore-on-create for a variable be saved with this `Trackable`. 761 762 If the user has requested that this object or another `Trackable` which 763 depends on this object be restored from a checkpoint (deferred loading 764 before variable object creation), `initializer` may be ignored and the value 765 from the checkpoint used instead. 766 767 Args: 768 name: A name for the variable. Must be unique within this object. 769 shape: The shape of the variable. 770 dtype: The data type of the variable. 771 initializer: The initializer to use. Ignored if there is a deferred 772 restoration left over from a call to 773 `_restore_from_checkpoint_position`. 774 getter: The getter to wrap which actually fetches the variable. 775 overwrite: If True, disables unique name and type checks. 776 **kwargs_for_getter: Passed to the getter. 777 778 Returns: 779 The new variable object. 780 781 Raises: 782 ValueError: If the variable name is not unique. 783 """ 784 self._maybe_initialize_trackable() 785 with ops.init_scope(): 786 if context.executing_eagerly(): 787 # If this is a variable with a single Tensor stored in the checkpoint, 788 # we can set that value as an initializer rather than initializing and 789 # then assigning (when executing eagerly). This call returns None if 790 # there is nothing to restore. 
791 checkpoint_initializer = self._preload_simple_restoration( 792 name=name) 793 else: 794 checkpoint_initializer = None 795 if (checkpoint_initializer is not None and 796 not (isinstance(initializer, CheckpointInitialValueCallable) and 797 (initializer.restore_uid > checkpoint_initializer.restore_uid))): 798 # If multiple Trackable objects are "creating" the same variable 799 # via the magic of custom getters, the one with the highest restore UID 800 # (the one called last) has to make the final initializer. If another 801 # custom getter interrupts this process by overwriting the initializer, 802 # then we'll catch that when we call _track_trackable. So this is 803 # "best effort" to set the initializer with the highest restore UID. 804 initializer = checkpoint_initializer 805 new_variable = getter( 806 name=name, 807 shape=shape, 808 dtype=dtype, 809 initializer=initializer, 810 **kwargs_for_getter) 811 812 # If we set an initializer and the variable processed it, tracking will not 813 # assign again. It will add this variable to our dependencies, and if there 814 # is a non-trivial restoration queued, it will handle that. This also 815 # handles slot variables. 816 if not overwrite or isinstance(new_variable, Trackable): 817 return self._track_trackable(new_variable, name=name, overwrite=overwrite) 818 else: 819 # TODO(allenl): Some variable types are not yet supported. Remove this 820 # fallback once all get_variable() return types are Trackable. 821 return new_variable 822 823 def _preload_simple_restoration(self, name): 824 """Return a dependency's value for restore-on-create. 825 826 Note the restoration is not deleted; if for some reason preload is called 827 and then not assigned to the variable (for example because a custom getter 828 overrides the initializer), the assignment will still happen once the 829 variable is tracked (determined based on checkpoint.restore_uid). 
830 831 Args: 832 name: The object-local name of the dependency holding the variable's 833 value. 834 835 Returns: 836 An callable for use as a variable's initializer/initial_value, or None if 837 one should not be set (either because there was no variable with this name 838 in the checkpoint or because it needs more complex deserialization). Any 839 non-trivial deserialization will happen when the variable object is 840 tracked. 841 """ 842 deferred_dependencies_list = self._deferred_dependencies.get(name, ()) 843 if not deferred_dependencies_list: 844 # Nothing to do; we don't have a restore for this dependency queued up. 845 return 846 for checkpoint_position in deferred_dependencies_list: 847 if not checkpoint_position.is_simple_variable(): 848 # If _any_ pending restoration is too complicated to fit in an 849 # initializer (because it has dependencies, or because there are 850 # multiple Tensors to restore), bail and let the general tracking code 851 # handle it. 852 return None 853 checkpoint_position = max( 854 deferred_dependencies_list, 855 key=lambda restore: restore.checkpoint.restore_uid) 856 return CheckpointInitialValueCallable( 857 checkpoint_position=checkpoint_position) 858 859 def _track_trackable(self, trackable, name, overwrite=False): 860 """Declare a dependency on another `Trackable` object. 861 862 Indicates that checkpoints for this object should include variables from 863 `trackable`. 864 865 Variables in a checkpoint are mapped to `Trackable`s based on the names 866 provided when the checkpoint was written. To avoid breaking existing 867 checkpoints when modifying a class, neither variable names nor dependency 868 names (the names passed to `_track_trackable`) may change. 869 870 Args: 871 trackable: A `Trackable` which this object depends on. 872 name: A local name for `trackable`, used for loading checkpoints into the 873 correct objects. 874 overwrite: Boolean, whether silently replacing dependencies is OK. 
Used 875 for __setattr__, where throwing an error on attribute reassignment would 876 be inappropriate. 877 878 Returns: 879 `trackable`, for convenience when declaring a dependency and 880 assigning to a member variable in one statement. 881 882 Raises: 883 TypeError: If `trackable` does not inherit from `Trackable`. 884 ValueError: If another object is already tracked by this name. 885 """ 886 self._maybe_initialize_trackable() 887 if not isinstance(trackable, Trackable): 888 raise TypeError(("Trackable._track_trackable() passed type %s, not a " 889 "Trackable.") % (type(trackable),)) 890 if not getattr(self, "_manual_tracking", True): 891 return trackable 892 new_reference = TrackableReference(name=name, ref=trackable) 893 current_object = self._lookup_dependency(name) 894 if (current_object is not None and current_object is not trackable): 895 if not overwrite: 896 raise ValueError( 897 ("Called Trackable._track_trackable() with name='%s', " 898 "but a Trackable with this name is already declared as a " 899 "dependency. Names must be unique (or overwrite=True).") % (name,)) 900 # This is a weird thing to do, but we're not going to stop people from 901 # using __setattr__. 902 for index, (old_name, _) in enumerate( 903 self._self_unconditional_checkpoint_dependencies): 904 if name == old_name: 905 self._self_unconditional_checkpoint_dependencies[ 906 index] = new_reference 907 elif current_object is None: 908 self._self_unconditional_checkpoint_dependencies.append(new_reference) 909 self._handle_deferred_dependencies(name=name, trackable=trackable) 910 self._self_unconditional_dependency_names[name] = trackable 911 return trackable 912 913 def _handle_deferred_dependencies(self, name, trackable): 914 """Pop and load any deferred checkpoint restores into `trackable`. 
915 916 This method does not add a new dependency on `trackable`, but it does 917 check if any outstanding/deferred dependencies have been queued waiting for 918 this dependency to be added (matched based on `name`). If so, 919 `trackable` and its dependencies are restored. The restorations are 920 considered fulfilled and so are deleted. 921 922 `_track_trackable` is more appropriate for adding a 923 normal/unconditional dependency, and includes handling for deferred 924 restorations. This method allows objects such as `Optimizer` to use the same 925 restoration logic while managing conditional dependencies themselves, by 926 overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the 927 object's dependencies based on the context it is saved/restored in (a single 928 optimizer instance can have state associated with multiple graphs). 929 930 Args: 931 name: The name of the dependency within this object (`self`), used to 932 match `trackable` with values saved in a checkpoint. 933 trackable: The Trackable object to restore (inheriting from `Trackable`). 934 """ 935 self._maybe_initialize_trackable() 936 trackable._maybe_initialize_trackable() # pylint: disable=protected-access 937 deferred_dependencies_list = self._deferred_dependencies.pop(name, ()) 938 for checkpoint_position in sorted( 939 deferred_dependencies_list, 940 key=lambda restore: restore.checkpoint.restore_uid, 941 reverse=True): 942 checkpoint_position.restore(trackable) 943 944 # Pass on any name-based restores queued in this object. 
945 for name_based_restore in sorted( 946 self._self_name_based_restores, 947 key=lambda checkpoint: checkpoint.restore_uid, 948 reverse=True): 949 trackable._name_based_attribute_restore(name_based_restore) # pylint: disable=protected-access 950 951 def _restore_from_checkpoint_position(self, checkpoint_position): 952 """Restore this object and its dependencies (may be deferred).""" 953 # Attempt a breadth-first traversal, since presumably the user has more 954 # control over shorter paths. If we don't have all of the dependencies at 955 # this point, the end result is not breadth-first (since other deferred 956 # traversals will happen later). 957 visit_queue = collections.deque([checkpoint_position]) 958 restore_ops = [] 959 tensor_saveables = {} 960 python_saveables = [] 961 while visit_queue: 962 current_position = visit_queue.popleft() 963 new_restore_ops, new_tensor_saveables, new_python_saveables = ( 964 current_position.trackable # pylint: disable=protected-access 965 ._single_restoration_from_checkpoint_position( 966 checkpoint_position=current_position, 967 visit_queue=visit_queue)) 968 restore_ops.extend(new_restore_ops) 969 tensor_saveables.update(new_tensor_saveables) 970 python_saveables.extend(new_python_saveables) 971 restore_ops.extend( 972 current_position.checkpoint.restore_saveables( 973 tensor_saveables, python_saveables)) 974 return restore_ops 975 976 def _single_restoration_from_checkpoint_position(self, checkpoint_position, 977 visit_queue): 978 """Restore this object, and either queue its dependencies or defer them.""" 979 self._maybe_initialize_trackable() 980 checkpoint = checkpoint_position.checkpoint 981 # If the UID of this restore is lower than our current update UID, we don't 982 # need to actually restore the object. However, we should pass the 983 # restoration on to our dependencies. 
984 if checkpoint.restore_uid > self._self_update_uid: 985 restore_ops, tensor_saveables, python_saveables = ( 986 checkpoint_position.gather_ops_or_named_saveables()) 987 self._self_update_uid = checkpoint.restore_uid 988 else: 989 restore_ops = () 990 tensor_saveables = {} 991 python_saveables = () 992 for child in checkpoint_position.object_proto.children: 993 child_position = CheckpointPosition( 994 checkpoint=checkpoint, proto_id=child.node_id) 995 local_object = self._lookup_dependency(child.local_name) 996 if local_object is None: 997 # We don't yet have a dependency registered with this name. Save it 998 # in case we do. 999 self._deferred_dependencies.setdefault(child.local_name, 1000 []).append(child_position) 1001 else: 1002 if child_position.bind_object(trackable=local_object): 1003 # This object's correspondence is new, so dependencies need to be 1004 # visited. Delay doing it so that we get a breadth-first dependency 1005 # resolution order (shallowest paths first). The caller is responsible 1006 # for emptying visit_queue. 1007 visit_queue.append(child_position) 1008 return restore_ops, tensor_saveables, python_saveables 1009 1010 def _gather_saveables_for_checkpoint(self): 1011 """Returns a dictionary of values to checkpoint with this object. 1012 1013 Keys in the returned dictionary are local to this object and in a separate 1014 namespace from dependencies. Values may either be `SaveableObject` factories 1015 or variables easily converted to `SaveableObject`s (as in 1016 `tf.compat.v1.train.Saver`'s 1017 `var_list` constructor argument). 1018 1019 `SaveableObjects` have a name set, which Trackable needs to generate 1020 itself. So rather than returning `SaveableObjects` directly, this method 1021 should return a dictionary of callables which take `name` arguments and 1022 return `SaveableObjects` with that name. 
1023 1024 If this object may also be passed to the global-name-based 1025 `tf.compat.v1.train.Saver`, 1026 the returned callables should have a default value for their name argument 1027 (i.e. be callable with no arguments). 1028 1029 Returned values must be saved only by this object; if any value may be 1030 shared, it should instead be a dependency. For example, variable objects 1031 save their own values with the key `VARIABLE_VALUE_KEY`, but objects which 1032 reference variables simply add a dependency. 1033 1034 Returns: 1035 The dictionary mapping attribute names to `SaveableObject` factories 1036 described above. For example: 1037 {VARIABLE_VALUE_KEY: 1038 lambda name="global_name_for_this_object": 1039 SaveableObject(name=name, ...)} 1040 """ 1041 return self._self_saveable_object_factories 1042 1043 def _list_extra_dependencies_for_serialization(self, serialization_cache): 1044 """Lists extra dependencies to serialize. 1045 1046 Internal sub-classes can override this method to return extra dependencies 1047 that should be saved with the object during SavedModel serialization. For 1048 example, this is used to save `trainable_variables` in Keras models. The 1049 python property `trainable_variables` contains logic to iterate through the 1050 weights from the model and its sublayers. The serialized Keras model saves 1051 `trainable_weights` as a trackable list of variables. 1052 1053 PLEASE NOTE when overriding this method: 1054 1. This function may only generate new trackable objects the first time it 1055 is called. 1056 2. The returned dictionary must not have naming conflicts with 1057 dependencies tracked by the root. In other words, if the root is 1058 tracking `object_1` with name 'x', and this functions returns 1059 `{'x': object_2}`, an error is raised when saving. 1060 1061 Args: 1062 serialization_cache: A dictionary shared between all objects in the same 1063 object graph. 
This object is passed to both 1064 `_list_extra_dependencies_for_serialization` and 1065 `_list_functions_for_serialization`. 1066 1067 Returns: 1068 A dictionary mapping attribute names to trackable objects. 1069 """ 1070 del serialization_cache 1071 return dict() 1072 1073 def _list_functions_for_serialization(self, serialization_cache): 1074 """Lists the functions of this trackable to serialize. 1075 1076 Internal sub-classes can override this with specific logic. E.g. 1077 `AutoTrackable` provides an implementation that returns the `attr` 1078 that return functions. 1079 1080 Args: 1081 serialization_cache: Dictionary passed to all objects in the same object 1082 graph during serialization. 1083 1084 Returns: 1085 A dictionary mapping attribute names to `Function` or 1086 `ConcreteFunction`. 1087 """ 1088 del serialization_cache 1089 return dict() 1090 1091 def _map_resources(self, save_options): # pylint: disable=unused-argument 1092 """Makes new resource handle ops corresponding to existing resource tensors. 1093 1094 Internal sub-classes can override this to inform model saving how to add new 1095 resource handle ops to the main GraphDef of a SavedModel (TF 1.x style 1096 graph), which allows session based APIs (e.g, C++ loader API) to interact 1097 with resources owned by this object. 1098 1099 Args: 1100 save_options: A tf.saved_model.SaveOptions instance. 1101 1102 Returns: 1103 A tuple of (object_map, resource_map): 1104 object_map: A dictionary mapping from objects that hold existing 1105 resource tensors to replacement objects created to hold the new 1106 resource tensors. 1107 resource_map: A dictionary mapping from existing resource tensors to 1108 newly created resource tensors. 1109 """ 1110 return {}, {} 1111