1"""An object-local variable management scheme.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21import collections 22import functools 23import json 24import weakref 25 26import six 27 28from tensorflow.python.eager import context 29from tensorflow.python.framework import constant_op 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import ops 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import control_flow_ops 34from tensorflow.python.ops import gen_io_ops as io_ops 35from tensorflow.python.platform import tf_logging as logging 36from tensorflow.python.training.saving import saveable_object 37from tensorflow.python.util import nest 38from tensorflow.python.util import serialization 39from tensorflow.python.util import tf_decorator 40 41 42# Key where the object graph proto is saved in a TensorBundle 43OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH" 44 45 46# A key indicating a variable's value in an object's checkpointed Tensors 47# (Trackable._gather_saveables_for_checkpoint). 
If this is the only key and 48# the object has no dependencies, then its value may be restored on object 49# creation (avoiding double assignment when executing eagerly). 50VARIABLE_VALUE_KEY = "VARIABLE_VALUE" 51OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON" 52 53TrackableReference = collections.namedtuple( 54 "TrackableReference", 55 [ 56 # The local name for this dependency. 57 "name", 58 # The Trackable object being referenced. 59 "ref" 60 ]) 61 62 63class CheckpointInitialValue(ops.Tensor): 64 """Tensor wrapper for managing update UIDs in `Variables`. 65 66 When supplied as an initial value, objects of this type let a `Variable` 67 (`Variable`, `ResourceVariable`, etc.) know the UID of the restore the initial 68 value came from. This allows deferred restorations to be sequenced in the 69 order the user specified them, and lets us fall back on assignment if an 70 initial value is not set (e.g. due to a custom getter interfering). 71 72 See comments in _add_variable_with_custom_getter for more information about 73 how `CheckpointInitialValue` is used. 74 """ 75 76 def __init__(self, checkpoint_position, shape=None): 77 self.wrapped_value = checkpoint_position.value_tensors()[ 78 VARIABLE_VALUE_KEY] 79 if shape: 80 # We need to set the static shape information on the initializer if 81 # possible so we don't get a variable with an unknown shape. 
class CheckpointInitialValue(ops.Tensor):
  """Tensor wrapper for managing update UIDs in `Variables`.

  When supplied as an initial value, objects of this type let a `Variable`
  (`Variable`, `ResourceVariable`, etc.) know the UID of the restore the initial
  value came from. This allows deferred restorations to be sequenced in the
  order the user specified them, and lets us fall back on assignment if an
  initial value is not set (e.g. due to a custom getter interfering).

  See comments in _add_variable_with_custom_getter for more information about
  how `CheckpointInitialValue` is used.
  """

  def __init__(self, checkpoint_position, shape=None):
    # The restored Tensor for this variable's value, looked up by the
    # VARIABLE_VALUE_KEY attribute name.
    self.wrapped_value = checkpoint_position.value_tensors()[
        VARIABLE_VALUE_KEY]
    if shape:
      # We need to set the static shape information on the initializer if
      # possible so we don't get a variable with an unknown shape.
      self.wrapped_value.set_shape(shape)
    self._checkpoint_position = checkpoint_position

  def __getattr__(self, attr):
    # Delegate unknown attribute lookups to the wrapped Tensor so this object
    # can stand in for it; fall back to normal lookup to surface a regular
    # AttributeError for genuinely missing attributes.
    try:
      return getattr(self.wrapped_value, attr)
    except AttributeError:
      return self.__getattribute__(attr)

  @property
  def checkpoint_position(self):
    return self._checkpoint_position


class NoRestoreSaveable(saveable_object.SaveableObject):
  """Embeds a tensor in a checkpoint with no restore ops."""

  def __init__(self, tensor, name, dtype=None):
    # An empty slice_spec ("") means the value is saved whole (not sliced).
    spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype)
    super(NoRestoreSaveable, self).__init__(tensor, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    # Restoring is intentionally a no-op; the value is write-only.
    return control_flow_ops.no_op()


@six.add_metaclass(abc.ABCMeta)
class PythonStateSaveable(saveable_object.SaveableObject):
  """An interface for saving/restoring volatile Python state."""

  @abc.abstractmethod
  def feed_dict_additions(self):
    """When running a graph, indicates fresh state to feed.

    Returns:
      A dictionary mapping `Tensor`s to current Python state.
    """
    pass

  @abc.abstractmethod
  def freeze(self):
    """Create a new `SaveableObject` which freezes current state as a constant.

    Used when executing eagerly to embed the current state as a constant, or
    when creating a static tf.train.Saver with the frozen current Python state.

    Returns:
      A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
      no Python state associated with it).
    """
    pass
class PythonStringStateSaveable(PythonStateSaveable):
  """Saves Python state in a checkpoint."""

  def __init__(self, name, state_callback, restore_callback=None):
    """Configure saving.

    Args:
      name: The checkpoint key to write to.
      state_callback: A function taking no arguments which returns a
        string. This function is run every time a checkpoint is written.
      restore_callback: A function taking a Python string, used to restore
        state. Optional; defaults to doing nothing, in which case it is ignored
        by status assertions such as assert_consumed().
    """
    self._has_trivial_state_callback = (restore_callback is None)

    def _state_callback_wrapper():
      # Run the user's callback outside any function-building graph.
      with ops.init_scope():
        return state_callback()

    self._state_callback = _state_callback_wrapper
    self._restore_callback = restore_callback
    with ops.device("/cpu:0"):
      # Placeholder-like constant; fresh state is fed at save time via
      # feed_dict_additions.
      self._save_string = constant_op.constant("", dtype=dtypes.string)
    spec = saveable_object.SaveSpec(
        self._save_string, "", name, dtype=dtypes.string)
    super(PythonStringStateSaveable, self).__init__(
        self._save_string, [spec], name)

  @property
  def optional_restore(self):
    """For values with no restore, relaxes assert_consumed()."""
    return self._has_trivial_state_callback

  def feed_dict_additions(self):
    """When running a graph, indicates fresh state to feed."""
    return {self._save_string: self._state_callback()}

  def freeze(self):
    """Create a frozen `SaveableObject` which saves the current state."""

    def _constant_state():
      return constant_op.constant(self._state_callback(), dtype=dtypes.string)

    return NoRestoreSaveable(
        tensor=_constant_state,
        dtype=dtypes.string,
        name=self.name)

  def python_restore(self, restored_strings):
    """Called to restore Python state."""
    if self._restore_callback:
      restored, = restored_strings
      self._restore_callback(restored)

  def restore(self, restored_tensors, restored_shapes):
    """Called to restore TensorFlow state (nothing to do)."""
    return control_flow_ops.no_op()


class CheckpointPosition(object):
  """Indicates a position within a `_CheckpointRestoreCoordinator`."""

  def __init__(self, checkpoint, proto_id):
    """Specify an object within a checkpoint.

    Args:
      checkpoint: A _CheckpointRestoreCoordinator object.
      proto_id: The index of this object in TrackableObjectGraph.nodes.
    """
    self._checkpoint = checkpoint
    self._proto_id = proto_id

  def restore(self, trackable):
    """Restore this value into `trackable`."""
    with ops.init_scope():
      if self.bind_object(trackable):
        # This object's correspondence with a checkpointed object is new, so
        # process deferred restorations for it and its dependencies.
        restore_ops = trackable._restore_from_checkpoint_position(self)  # pylint: disable=protected-access
        if restore_ops:
          self._checkpoint.new_restore_ops(restore_ops)

  def bind_object(self, trackable):
    """Set a checkpoint<->object correspondence and process slot variables.

    Args:
      trackable: The object to record a correspondence for.
    Returns:
      True if this is a new assignment, False if this object has already been
      mapped to a checkpointed `Object` proto.
    Raises:
      AssertionError: If another object is already bound to the `Object` proto.
    """
    checkpoint = self.checkpoint
    checkpoint.all_python_objects.add(trackable)
    current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
    if current_assignment is None:
      checkpoint.object_by_proto_id[self._proto_id] = trackable
      # Process slot restorations which were waiting for this object (the
      # optimizer-tracked variable) to be created.
      for deferred_slot_restoration in (
          checkpoint.deferred_slot_restorations.pop(self._proto_id, ())):
        trackable._create_or_restore_slot_variable(  # pylint: disable=protected-access
            slot_variable_position=CheckpointPosition(
                checkpoint=checkpoint,
                proto_id=deferred_slot_restoration.slot_variable_id),
            variable=deferred_slot_restoration.original_variable,
            slot_name=deferred_slot_restoration.slot_name)
      for slot_restoration in checkpoint.slot_restorations.pop(
          self._proto_id, ()):
        optimizer_object = checkpoint.object_by_proto_id.get(
            slot_restoration.optimizer_id, None)
        if optimizer_object is None:
          # The optimizer has not yet been created or tracked. Record in the
          # checkpoint that the slot variables need to be restored when it is.
          checkpoint.deferred_slot_restorations.setdefault(
              slot_restoration.optimizer_id, []).append(
                  _DeferredSlotVariableRestoration(
                      original_variable=trackable,
                      slot_variable_id=slot_restoration.slot_variable_id,
                      slot_name=slot_restoration.slot_name))
        else:
          optimizer_object._create_or_restore_slot_variable(  # pylint: disable=protected-access
              slot_variable_position=CheckpointPosition(
                  checkpoint=checkpoint,
                  proto_id=slot_restoration.slot_variable_id),
              variable=trackable,
              slot_name=slot_restoration.slot_name)
      return True  # New assignment
    else:
      # The object was already mapped for this checkpoint load, which means
      # we don't need to do anything besides check that the mapping is
      # consistent (if the dependency DAG is not a tree then there are
      # multiple paths to the same object).
      if current_assignment is not trackable:
        logging.warning(
            ("Inconsistent references when loading the checkpoint into this "
             "object graph. Either the Trackable object references in the "
             "Python program have changed in an incompatible way, or the "
             "checkpoint was generated in an incompatible program.\n\nTwo "
             "checkpoint references resolved to different objects (%s and %s).")
            % (current_assignment, trackable))
      return False  # Not a new assignment

  def is_simple_variable(self):
    """Determine whether this value is restorable with a Tensor initializer."""
    attributes = self.object_proto.attributes
    return (len(attributes) == 1
            and attributes[0].name == VARIABLE_VALUE_KEY
            and not self.object_proto.children)

  def value_tensors(self):
    """Create value `Tensor`s for this object's attributes.

    Does not require that the Python object has been created. Used for
    restore-on-create when executing eagerly.

    Returns:
      A dictionary mapping from object attribute names to `Tensor`s.
    """
    value_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      with ops.init_scope():
        with ops.device("/cpu:0"):
          # Run the restore itself on the CPU.
          value, = io_ops.restore_v2(
              prefix=self._checkpoint.save_path_tensor,
              tensor_names=[checkpoint_key],
              shape_and_slices=[""],
              dtypes=[base_type],
              name="%s_checkpoint_read" % (serialized_tensor.name,))
        # Copy the value to the current device if necessary.
        value_tensors[serialized_tensor.name] = array_ops.identity(value)
    return value_tensors

  def _gather_ops_or_named_saveables(self):
    """Looks up or creates SaveableObjects which don't have cached ops."""
    saveables = self.trackable._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
    # Name saveables based on the name this object had when it was checkpointed.
    named_saveables = {}
    python_saveables = []
    existing_restore_ops = []
    for serialized_tensor in self.object_proto.attributes:
      if context.executing_eagerly():
        existing_op = None
      else:
        existing_op = self._checkpoint.restore_ops_by_name.get(
            serialized_tensor.checkpoint_key, None)
      if existing_op is not None:
        existing_restore_ops.append(existing_op)
        continue

      # Only if we don't have cached ops for this SaveableObject, we'll see if
      # the SaveableObject itself has been cached. If not, we'll make it, and
      # either way we'll extract new ops from it (or if it has Python state to
      # restore, we'll run that).
      saveables_cache = self._checkpoint.graph_view.saveables_cache
      if saveables_cache is None:
        # No SaveableObject caching when executing eagerly.
        saveable = None
      else:
        # If we've already created and cached a SaveableObject for this
        # attribute, we can re-use it to avoid re-creating some ops when graph
        # building.
        saveable_list = saveables_cache.get(
            self.trackable, {}).get(serialized_tensor.name, (None,))
        if len(saveable_list) == 1:
          # Almost every attribute will have exactly one SaveableObject.
          saveable, = saveable_list
        else:
          # Don't use cached SaveableObjects for partitioned variables, which is
          # the only case where we'd have a list of SaveableObjects. Op caching
          # will catch them.
          saveable = None
      if saveable is not None:
        # The name of this attribute has changed, so we need to re-generate
        # the SaveableObject.
        if serialized_tensor.checkpoint_key not in saveable.name:
          saveable = None
          del saveables_cache[self.trackable]
          break
      if saveable is None:
        # If there was no cached SaveableObject, we should check if the Python
        # object has the attribute.
        saveable_factory = saveables.get(serialized_tensor.name, None)
        if saveable_factory is None:
          # Purposefully does not throw an exception if attributes have been
          # added or deleted. Stores unused attributes so an exception can be
          # raised if the user decides to check that everything in the
          # checkpoint was loaded.
          if not serialized_tensor.optional_restore:
            self._checkpoint.unused_attributes.setdefault(
                self.trackable, []).append(serialized_tensor.name)
          continue
        if callable(saveable_factory):
          saveable = saveable_factory(name=serialized_tensor.checkpoint_key)
        else:
          saveable = saveable_factory
        if saveables_cache is not None:
          saveables_cache.setdefault(
              self.trackable, {})[serialized_tensor.name] = [saveable]
      if isinstance(saveable, PythonStateSaveable):
        python_saveables.append(saveable)
      else:
        named_saveables[serialized_tensor.checkpoint_key] = saveable
    return existing_restore_ops, named_saveables, python_saveables

  def restore_ops(self):
    """Create or fetch restore ops for this object's attributes.

    Requires that the `Trackable` Python object has been bound to an object
    ID in the checkpoint.

    Returns:
      A list of operations when graph building, or an empty list when executing
      eagerly.
    """
    (restore_ops,
     tensor_saveables,
     python_saveables) = self._gather_ops_or_named_saveables()
    restore_ops.extend(self._checkpoint.restore_saveables(
        tensor_saveables, python_saveables))
    return restore_ops

  @property
  def checkpoint(self):
    return self._checkpoint

  @property
  def trackable(self):
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def restore_uid(self):
    return self._checkpoint.restore_uid

  def __repr__(self):
    return repr(self.object_proto)


_DeferredSlotVariableRestoration = collections.namedtuple(
    "_DeferredSlotVariableRestoration",
    [
        "original_variable",
        "slot_variable_id",
        "slot_name",
    ]
)

_SlotVariableRestoration = collections.namedtuple(
    "_SlotVariableRestoration",
    [
        # The checkpoint proto id of the optimizer object.
        "optimizer_id",
        # The checkpoint proto id of the slot variable.
        "slot_variable_id",
        "slot_name",
    ])


def no_automatic_dependency_tracking(method):
  """Disables automatic dependency tracking on attribute assignment.

  Use to decorate any method of a Trackable object. Attribute assignment in
  that method will not add dependencies (also respected in Model). Harmless if
  used in a class which does not do automatic dependency tracking (which means
  it's safe to use in base classes which may have subclasses which also inherit
  from Trackable).

  Args:
    method: The method to decorate.
  Returns:
    A decorated method which sets and un-sets automatic dependency tracking for
    the object the method is called on (not thread safe).
  """

  def _method_wrapper(self, *args, **kwargs):
    previous_value = getattr(self, "_setattr_tracking", True)
    self._setattr_tracking = False  # pylint: disable=protected-access
    try:
      result = method(self, *args, **kwargs)
    finally:
      # Restore the previous tracking state even if `method` raises.
      self._setattr_tracking = previous_value  # pylint: disable=protected-access
    return result

  return tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)
class Trackable(object):
  """Base class for `Trackable` objects without automatic dependencies.

  This class has no __setattr__ override for performance reasons. Dependencies
  must be added explicitly. Unless attribute assignment is performance-critical,
  use `AutoTrackable` instead. Use `Trackable` for `isinstance`
  checks.
  """

  # Trackable does not do automatic dependency tracking, but uses the
  # no_automatic_dependency_tracking decorator so it can avoid adding
  # dependencies if a subclass is Trackable / inherits from Model (both of
  # which have __setattr__ overrides).
  @no_automatic_dependency_tracking
  def _maybe_initialize_trackable(self):
    """Initialize dependency management.

    Not __init__, since most objects will forget to call it.
    """
    if hasattr(self, "_unconditional_checkpoint_dependencies"):
      # __init__ already called. This check means that we don't need
      # Trackable.__init__() in the constructor of every TensorFlow object.
      return
    # A list of TrackableReference objects. Some classes implementing
    # `Trackable`, notably `Optimizer`s, may override the
    # _checkpoint_dependencies property with conditional dependencies
    # (e.g. based on the current graph when saving).
    self._unconditional_checkpoint_dependencies = []
    # Maps names -> Trackable objects
    self._unconditional_dependency_names = {}
    # Restorations for other Trackable objects on which this object may
    # eventually depend. Maps local name -> CheckpointPosition list. Optimizers
    # tack on conditional dependencies, and so need separate management of
    # deferred dependencies too.
    self._unconditional_deferred_dependencies = {}
    # The UID of the highest assignment to this object. Used to ensure that the
    # last requested assignment determines the final value of an object.
    if hasattr(self, "_update_uid"):
      raise AssertionError(
          "Internal error: the object had an update UID set before its "
          "initialization code was run.")
    self._update_uid = -1
    # When executing eagerly, holds a collection of _NameBasedRestoreCoordinator
    # instances, which should be checked when creating variables or other
    # saveables. These are passed on recursively to all dependencies, since
    # unlike object-based checkpoint restores we don't know which subgraph is
    # being restored in advance. This mechanism is only necessary for
    # restore-on-create when executing eagerly, and so is unused when graph
    # building.
    self._name_based_restores = set()

  def _no_dependency(self, value):
    """If automatic dependency tracking is enabled, ignores `value`."""
    return value

  def _name_based_attribute_restore(self, checkpoint):
    """Restore the object's attributes from a name-based checkpoint."""
    self._name_based_restores.add(checkpoint)
    if self._update_uid < checkpoint.restore_uid:
      checkpoint.eager_restore(self)
      self._update_uid = checkpoint.restore_uid

  @property
  def _checkpoint_dependencies(self):
    """All dependencies of this object.

    May be overridden to include conditional dependencies.

    Returns:
      A list of `TrackableReference` objects indicating named
      `Trackable` dependencies which should be saved along with this
      object.
    """
    return self._unconditional_checkpoint_dependencies

  @property
  def _deferred_dependencies(self):
    """A dictionary with deferred dependencies.

    Stores restorations for other Trackable objects on which this object
    may eventually depend. May be overridden by sub-classes (e.g. Optimizers use
    conditional dependencies based the current graph, and so need separate
    management of deferred dependencies too).

    Returns:
      A dictionary mapping from local name to a list of CheckpointPosition
      objects.
    """
    return self._unconditional_deferred_dependencies

  def _lookup_dependency(self, name):
    """Look up a dependency by name.

    May be overridden to include conditional dependencies.

    Args:
      name: The local name of the dependency.
    Returns:
      A `Trackable` object, or `None` if no dependency by this name was
      found.
    """
    return self._unconditional_dependency_names.get(name, None)

  def _add_variable_with_custom_getter(
      self, name, shape=None, dtype=dtypes.float32,
      initializer=None, getter=None, overwrite=False,
      **kwargs_for_getter):
    """Restore-on-create for a variable be saved with this `Trackable`.

    If the user has requested that this object or another `Trackable` which
    depends on this object be restored from a checkpoint (deferred loading
    before variable object creation), `initializer` may be ignored and the value
    from the checkpoint used instead.

    Args:
      name: A name for the variable. Must be unique within this object.
      shape: The shape of the variable.
      dtype: The data type of the variable.
      initializer: The initializer to use. Ignored if there is a deferred
        restoration left over from a call to
        `_restore_from_checkpoint_position`.
      getter: The getter to wrap which actually fetches the variable.
      overwrite: If True, disables unique name and type checks.
      **kwargs_for_getter: Passed to the getter.

    Returns:
      The new variable object.

    Raises:
      ValueError: If the variable name is not unique.
    """
    self._maybe_initialize_trackable()
    with ops.init_scope():
      if context.executing_eagerly():
        # If this is a variable with a single Tensor stored in the checkpoint,
        # we can set that value as an initializer rather than initializing and
        # then assigning (when executing eagerly). This call returns None if
        # there is nothing to restore.
        checkpoint_initializer = self._preload_simple_restoration(
            name=name, shape=shape)
      else:
        checkpoint_initializer = None
      if (checkpoint_initializer is not None
          and not (
              isinstance(initializer, CheckpointInitialValue)
              and (initializer.restore_uid
                   > checkpoint_initializer.restore_uid))):
        # If multiple Trackable objects are "creating" the same variable
        # via the magic of custom getters, the one with the highest restore UID
        # (the one called last) has to make the final initializer. If another
        # custom getter interrupts this process by overwriting the initializer,
        # then we'll catch that when we call _track_trackable. So this is
        # "best effort" to set the initializer with the highest restore UID.
        initializer = checkpoint_initializer
        shape = None
    new_variable = getter(
        name=name, shape=shape, dtype=dtype, initializer=initializer,
        **kwargs_for_getter)

    # If we set an initializer and the variable processed it, tracking will not
    # assign again. It will add this variable to our dependencies, and if there
    # is a non-trivial restoration queued, it will handle that. This also
    # handles slot variables.
    if not overwrite or isinstance(new_variable, Trackable):
      return self._track_trackable(new_variable, name=name,
                                   overwrite=overwrite)
    else:
      # TODO(allenl): Some variable types are not yet supported. Remove this
      # fallback once all get_variable() return types are Trackable.
      return new_variable

  def _preload_simple_restoration(self, name, shape):
    """Return a dependency's value for restore-on-create.

    Note the restoration is not deleted; if for some reason preload is called
    and then not assigned to the variable (for example because a custom getter
    overrides the initializer), the assignment will still happen once the
    variable is tracked (determined based on checkpoint.restore_uid).

    Args:
      name: The object-local name of the dependency holding the variable's
        value.
      shape: The shape of the variable being loaded into.
    Returns:
      An callable for use as a variable's initializer/initial_value, or None if
      one should not be set (either because there was no variable with this name
      in the checkpoint or because it needs more complex deserialization). Any
      non-trivial deserialization will happen when the variable object is
      tracked.
    """
    deferred_dependencies_list = self._deferred_dependencies.get(name, ())
    if not deferred_dependencies_list:
      # Nothing to do; we don't have a restore for this dependency queued up.
      return
    for checkpoint_position in deferred_dependencies_list:
      if not checkpoint_position.is_simple_variable():
        # If _any_ pending restoration is too complicated to fit in an
        # initializer (because it has dependencies, or because there are
        # multiple Tensors to restore), bail and let the general tracking code
        # handle it.
        return None
    checkpoint_position = max(
        deferred_dependencies_list,
        key=lambda restore: restore.checkpoint.restore_uid)
    return CheckpointInitialValue(
        checkpoint_position=checkpoint_position, shape=shape)

  def _track_trackable(self, trackable, name, overwrite=False):
    """Declare a dependency on another `Trackable` object.

    Indicates that checkpoints for this object should include variables from
    `trackable`.

    Variables in a checkpoint are mapped to `Trackable`s based on the names
    provided when the checkpoint was written. To avoid breaking existing
    checkpoints when modifying a class, neither variable names nor dependency
    names (the names passed to `_track_trackable`) may change.

    Args:
      trackable: A `Trackable` which this object depends on.
      name: A local name for `trackable`, used for loading checkpoints into
        the correct objects.
      overwrite: Boolean, whether silently replacing dependencies is OK. Used
        for __setattr__, where throwing an error on attribute reassignment would
        be inappropriate.

    Returns:
      `trackable`, for convenience when declaring a dependency and
      assigning to a member variable in one statement.

    Raises:
      TypeError: If `trackable` does not inherit from `Trackable`.
      ValueError: If another object is already tracked by this name.
    """
    self._maybe_initialize_trackable()
    if not isinstance(trackable, Trackable):
      raise TypeError(
          ("Trackable._track_trackable() passed type %s, not a "
           "Trackable.") % (type(trackable),))
    new_reference = TrackableReference(name=name, ref=trackable)
    current_object = self._lookup_dependency(name)
    if (current_object is not None
        and current_object is not trackable):
      if not overwrite:
        raise ValueError(
            ("Called Trackable._track_trackable() with name='%s', "
             "but a Trackable with this name is already declared as a "
             "dependency. Names must be unique (or overwrite=True).") % (name,))
      # This is a weird thing to do, but we're not going to stop people from
      # using __setattr__.
      for index, (old_name, _) in enumerate(
          self._unconditional_checkpoint_dependencies):
        if name == old_name:
          self._unconditional_checkpoint_dependencies[index] = new_reference
    elif current_object is None:
      self._unconditional_checkpoint_dependencies.append(new_reference)
      self._handle_deferred_dependencies(
          name=name, trackable=trackable)
    self._unconditional_dependency_names[name] = trackable
    return trackable

  def _handle_deferred_dependencies(self, name, trackable):
    """Pop and load any deferred checkpoint restores into `trackable`.

    This method does not add a new dependency on `trackable`, but it does
    check if any outstanding/deferred dependencies have been queued waiting for
    this dependency to be added (matched based on `name`). If so,
    `trackable` and its dependencies are restored. The restorations are
    considered fulfilled and so are deleted.

    `_track_trackable` is more appropriate for adding a
    normal/unconditional dependency, and includes handling for deferred
    restorations. This method allows objects such as `Optimizer` to use the same
    restoration logic while managing conditional dependencies themselves, by
    overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the
    object's dependencies based on the context it is saved/restored in (a single
    optimizer instance can have state associated with multiple graphs).

    Args:
      name: The name of the dependency within this object (`self`), used to
        match `trackable` with values saved in a checkpoint.
      trackable: The Trackable object to restore (inheriting from
        `Trackable`).
    """
    self._maybe_initialize_trackable()
    trackable._maybe_initialize_trackable()  # pylint: disable=protected-access
    deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
    # Restore highest-UID (most recently requested) checkpoints first so the
    # last requested restore wins.
    for checkpoint_position in sorted(
        deferred_dependencies_list,
        key=lambda restore: restore.checkpoint.restore_uid,
        reverse=True):
      checkpoint_position.restore(trackable)

    # Pass on any name-based restores queued in this object.
    for name_based_restore in sorted(
        self._name_based_restores,
        key=lambda checkpoint: checkpoint.restore_uid,
        reverse=True):
      trackable._name_based_attribute_restore(name_based_restore)  # pylint: disable=protected-access

  def _restore_from_checkpoint_position(self, checkpoint_position):
    """Restore this object and its dependencies (may be deferred)."""
    # Attempt a breadth-first traversal, since presumably the user has more
    # control over shorter paths. If we don't have all of the dependencies at
    # this point, the end result is not breadth-first (since other deferred
    # traversals will happen later).
    visit_queue = collections.deque([checkpoint_position])
    restore_ops = []
    while visit_queue:
      current_position = visit_queue.popleft()
      restore_ops.extend(nest.flatten(
          current_position.trackable  # pylint: disable=protected-access
          ._single_restoration_from_checkpoint_position(
              checkpoint_position=current_position,
              visit_queue=visit_queue)))
    return restore_ops

  def _single_restoration_from_checkpoint_position(
      self, checkpoint_position, visit_queue):
    """Restore this object, and either queue its dependencies or defer them."""
    self._maybe_initialize_trackable()
    checkpoint = checkpoint_position.checkpoint
    # If the UID of this restore is lower than our current update UID, we don't
    # need to actually restore the object. However, we should pass the
    # restoration on to our dependencies.
    if checkpoint.restore_uid > self._update_uid:
      restore_ops = checkpoint_position.restore_ops()
      self._update_uid = checkpoint.restore_uid
    else:
      restore_ops = ()
    for child in checkpoint_position.object_proto.children:
      child_position = CheckpointPosition(
          checkpoint=checkpoint,
          proto_id=child.node_id)
      local_object = self._lookup_dependency(child.local_name)
      if local_object is None:
        # We don't yet have a dependency registered with this name. Save it
        # in case we do.
        self._deferred_dependencies.setdefault(child.local_name, []).append(
            child_position)
      else:
        if child_position.bind_object(trackable=local_object):
          # This object's correspondence is new, so dependencies need to be
          # visited. Delay doing it so that we get a breadth-first dependency
          # resolution order (shallowest paths first). The caller is responsible
          # for emptying visit_queue.
          visit_queue.append(child_position)
    return restore_ops

  def _gather_saveables_for_checkpoint(self):
    """Returns a dictionary of values to checkpoint with this object.

    Keys in the returned dictionary are local to this object and in a separate
    namespace from dependencies. Values may either be `SaveableObject` factories
    or variables easily converted to `SaveableObject`s (as in `tf.train.Saver`'s
    `var_list` constructor argument).

    `SaveableObjects` have a name set, which Trackable needs to generate
    itself. So rather than returning `SaveableObjects` directly, this method
    should return a dictionary of callables which take `name` arguments and
    return `SaveableObjects` with that name.

    If this object may also be passed to the global-name-based `tf.train.Saver`,
    the returned callables should have a default value for their name argument
    (i.e. be callable with no arguments).

    Returned values must be saved only by this object; if any value may be
    shared, it should instead be a dependency. For example, variable objects
    save their own values with the key `VARIABLE_VALUE_KEY`, but objects which
    reference variables simply add a dependency.

    Returns:
      The dictionary mapping attribute names to `SaveableObject` factories
      described above. For example:
      {VARIABLE_VALUE_KEY:
       lambda name="global_name_for_this_object":
       SaveableObject(name=name, ...)}
    """
    if not hasattr(self, "get_config"):
      return {}
    try:
      self.get_config()
    except NotImplementedError:
      return {}
    # Hold `self` weakly so the saveable does not keep this object alive.
    weak_self = weakref.ref(self)

    def _state_callback():
      """Serializes `self.get_config()` for saving."""
      dereferenced_self = weak_self()
      if dereferenced_self:
        try:
          return json.dumps(
              dereferenced_self,
              default=serialization.get_json_type,
              sort_keys=True).encode("utf8")
        except TypeError:
          # Even if get_config worked objects may have produced garbage.
          return ""
      else:
        return ""

    return {OBJECT_CONFIG_JSON_KEY: functools.partial(
        PythonStringStateSaveable,
        state_callback=_state_callback)}

  def _list_functions_for_serialization(self):
    """Lists the functions of this trackable to serialize.

    Internal sub-classes can override this with specific logic. E.g.
    `AutoTrackable` provides an implementation that returns the `attr`
    that return functions.

    Returns:
      A dictionary mapping attribute names to `Function` or
      `ConcreteFunction`.
    """
    return dict()