1"""An object-local variable management scheme.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import weakref 22 23from tensorflow.python import pywrap_tensorflow 24from tensorflow.python.eager import context 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.ops import gen_io_ops as io_ops 28from tensorflow.python.util import nest 29 30# A key indicating a variable's value in an object's checkpointed Tensors 31# (Checkpointable._gather_tensors_for_checkpoint). If this is the only key and 32# the object has no dependencies, then its value may be restored on object 33# creation (avoiding double assignment when executing eagerly). 34VARIABLE_VALUE_KEY = "VARIABLE_VALUE" 35 36_CheckpointableReference = collections.namedtuple( 37 "_CheckpointableReference", 38 [ 39 # The local name for this dependency. 40 "name", 41 # The Checkpointable object being referenced. 42 "ref" 43 ]) 44 45 46class CheckpointInitialValue(ops.Tensor): 47 """Tensor wrapper for managing update UIDs in `Variables`. 48 49 When supplied as an initial value, objects of this type let a `Variable` 50 (`Variable`, `ResourceVariable`, etc.) know the UID of the restore the initial 51 value came from. This allows deferred restorations to be sequenced in the 52 order the user specified them, and lets us fall back on assignment if an 53 initial value is not set (e.g. due to a custom getter interfering). 54 55 See comments in _add_variable_with_custom_getter for more information about 56 how `CheckpointInitialValue` is used. 57 """ 58 59 def __init__(self, checkpoint_position, shape=None): 60 self.wrapped_value = checkpoint_position.restore_ops()[ 61 VARIABLE_VALUE_KEY] 62 if shape: 63 # We need to set the static shape information on the initializer if 64 # possible so we don't get a variable with an unknown shape. 65 self.wrapped_value.set_shape(shape) 66 self._checkpoint_position = checkpoint_position 67 68 @property 69 def __class__(self): 70 return (self.wrapped_value.__class__, CheckpointInitialValue) 71 72 def __getattr__(self, attr): 73 try: 74 return getattr(self.wrapped_value, attr) 75 except AttributeError: 76 return self.__getattribute__(attr) 77 78 @property 79 def checkpoint_position(self): 80 return self._checkpoint_position 81 82 83class _CheckpointPosition(object): 84 """Indicates a position within a `_Checkpoint`.""" 85 86 def __init__(self, checkpoint, proto_id): 87 """Specify an object within a checkpoint. 88 89 Args: 90 checkpoint: A _Checkpoint object. 91 proto_id: The index of this object in CheckpointableObjectGraph.nodes. 
92 """ 93 self._checkpoint = checkpoint 94 self._proto_id = proto_id 95 96 def restore(self, checkpointable): 97 """Restore this value into `checkpointable`.""" 98 if self.bind_object(checkpointable): 99 # This object's correspondence with a checkpointed object is new, so 100 # process deferred restorations for it and its dependencies. 101 restore_ops = checkpointable._restore_from_checkpoint_position(self) # pylint: disable=protected-access 102 if restore_ops: 103 self._checkpoint.restore_ops.extend(restore_ops) 104 105 def bind_object(self, checkpointable): 106 """Set a checkpoint<->object correspondence and process slot variables. 107 108 Args: 109 checkpointable: The object to record a correspondence for. 110 Returns: 111 True if this is a new assignment, False if this object has already been 112 mapped to a checkpointed `Object` proto. 113 Raises: 114 AssertionError: If another object is already bound to the `Object` proto. 115 """ 116 checkpoint = self.checkpoint 117 current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None) 118 if current_assignment is None: 119 checkpoint.object_by_proto_id[self._proto_id] = checkpointable 120 for deferred_slot_restoration in ( 121 checkpoint.deferred_slot_restorations.pop(self._proto_id, ())): 122 checkpointable._create_or_restore_slot_variable( # pylint: disable=protected-access 123 slot_variable_position=_CheckpointPosition( 124 checkpoint=checkpoint, 125 proto_id=deferred_slot_restoration.slot_variable_id), 126 variable=deferred_slot_restoration.original_variable, 127 slot_name=deferred_slot_restoration.slot_name) 128 for slot_restoration in checkpoint.slot_restorations.pop( 129 self._proto_id, ()): 130 optimizer_object = checkpoint.object_by_proto_id.get( 131 slot_restoration.optimizer_id, None) 132 if optimizer_object is None: 133 # The optimizer has not yet been created or tracked. Record in the 134 # checkpoint that the slot variables need to be restored when it is. 135 checkpoint.deferred_slot_restorations.setdefault( 136 slot_restoration.optimizer_id, []).append( 137 _DeferredSlotVariableRestoration( 138 original_variable=checkpointable, 139 slot_variable_id=slot_restoration.slot_variable_id, 140 slot_name=slot_restoration.slot_name)) 141 else: 142 optimizer_object._create_or_restore_slot_variable( # pylint: disable=protected-access 143 slot_variable_position=_CheckpointPosition( 144 checkpoint=checkpoint, 145 proto_id=slot_restoration.slot_variable_id), 146 variable=checkpointable, 147 slot_name=slot_restoration.slot_name) 148 return True # New assignment 149 else: 150 # The object was already mapped for this checkpoint load, which means 151 # we don't need to do anything besides check that the mapping is 152 # consistent (if the dependency DAG is not a tree then there are 153 # multiple paths to the same object). 154 if current_assignment is not checkpointable: 155 raise AssertionError( 156 ("Unable to load the checkpoint into this object graph. 
Either " 157 "the Checkpointable object references in the Python program " 158 "have changed in an incompatible way, or the checkpoint was " 159 "generated in an incompatible program.\n\nTwo checkpoint " 160 "references resolved to different objects (%s and %s).") 161 % (current_assignment, checkpointable)) 162 return False # Not a new assignment 163 164 def is_simple_variable(self): 165 """Determine whether this value is restorable with a Tensor initializer.""" 166 attributes = self.object_proto.attributes 167 return (len(attributes) == 1 168 and attributes[0].name == VARIABLE_VALUE_KEY 169 and not self.object_proto.children) 170 171 def restore_ops(self): 172 """Create restore ops for this object's attributes.""" 173 restore_tensors = {} 174 for serialized_tensor in self.object_proto.attributes: 175 checkpoint_key = serialized_tensor.checkpoint_key 176 dtype = self._checkpoint.dtype_map[checkpoint_key] 177 base_type = dtype.base_dtype 178 with ops.init_scope(): 179 restore, = io_ops.restore_v2( 180 prefix=self._checkpoint.save_path, 181 tensor_names=[checkpoint_key], 182 shape_and_slices=[""], 183 dtypes=[base_type], 184 name="%s_checkpoint_read" % (serialized_tensor.name,)) 185 restore_tensors[serialized_tensor.name] = restore 186 return restore_tensors 187 188 @property 189 def checkpoint(self): 190 return self._checkpoint 191 192 @property 193 def checkpointable(self): 194 return self._checkpoint.object_by_proto_id[self._proto_id] 195 196 @property 197 def object_proto(self): 198 return self._checkpoint.object_graph_proto.nodes[self._proto_id] 199 200 @property 201 def restore_uid(self): 202 return self._checkpoint.restore_uid 203 204 def __repr__(self): 205 return repr(self.object_proto) 206 207 208_DeferredSlotVariableRestoration = collections.namedtuple( 209 "_DeferredSlotVariableRestoration", 210 [ 211 "original_variable", 212 "slot_variable_id", 213 "slot_name", 214 ] 215) 216 217_SlotVariableRestoration = collections.namedtuple( 218 "_SlotVariableRestoration", 219 [ 220 # The checkpoint proto id of the optimizer object. 221 "optimizer_id", 222 # The checkpoint proto id of the slot variable. 223 "slot_variable_id", 224 "slot_name", 225 ]) 226 227 228class _Checkpoint(object): 229 """Holds the status of an object-based checkpoint load.""" 230 231 def __init__(self, object_graph_proto, save_path): 232 """Specify the checkpoint being loaded. 233 234 Args: 235 object_graph_proto: The CheckpointableObjectGraph protocol buffer 236 associated with this checkpoint. 237 save_path: The path to the checkpoint, as returned by 238 `tf.train.latest_checkpoint`. 239 """ 240 self.object_graph_proto = object_graph_proto 241 self.restore_uid = ops.uid() 242 # Dictionary mapping from an id in the protocol buffer flat array to 243 # Checkpointable Python objects. This mapping may be deferred if a 244 # checkpoint is restored before all dependencies have been tracked. Uses 245 # weak references so that partial restorations don't create reference cycles 246 # (as objects with deferred dependencies will generally have references to 247 # this object). 248 self.object_by_proto_id = weakref.WeakValueDictionary() 249 self.save_path = save_path 250 reader = pywrap_tensorflow.NewCheckpointReader(save_path) 251 self.dtype_map = reader.get_variable_to_dtype_map() 252 # When graph building, contains a list of ops to run to restore objects from 253 # this checkpoint. 254 self.restore_ops = [] 255 # A mapping from optimizer proto ids to lists of slot variables to be 256 # restored when the optimizer is tracked. 


class _Checkpoint(object):
  """Holds the status of an object-based checkpoint load."""

  def __init__(self, object_graph_proto, save_path):
    """Specify the checkpoint being loaded.

    Args:
      object_graph_proto: The CheckpointableObjectGraph protocol buffer
        associated with this checkpoint.
      save_path: The path to the checkpoint, as returned by
        `tf.train.latest_checkpoint`.
    """
    self.object_graph_proto = object_graph_proto
    self.restore_uid = ops.uid()
    # Dictionary mapping from an id in the protocol buffer flat array to
    # Checkpointable Python objects. This mapping may be deferred if a
    # checkpoint is restored before all dependencies have been tracked. Uses
    # weak references so that partial restorations don't create reference
    # cycles (as objects with deferred dependencies will generally have
    # references to this object).
    self.object_by_proto_id = weakref.WeakValueDictionary()
    self.save_path = save_path
    reader = pywrap_tensorflow.NewCheckpointReader(save_path)
    self.dtype_map = reader.get_variable_to_dtype_map()
    # When graph building, contains a list of ops to run to restore objects
    # from this checkpoint.
    self.restore_ops = []
    # A mapping from optimizer proto ids to lists of slot variables to be
    # restored when the optimizer is tracked. Only includes slot variables
    # whose regular variables have already been created, and only for
    # optimizer objects which have not yet been created/tracked.
    self.deferred_slot_restorations = {}
    # A mapping from variable proto ids to lists of slot variables to be
    # restored when the variable is created/tracked. These get shifted over
    # to deferred_slot_restorations if the optimizer hasn't been created when
    # that happens.
    self.slot_restorations = {}
    for node_index, node in enumerate(self.object_graph_proto.nodes):
      for slot_reference in node.slot_variables:
        # `node` refers to an `Optimizer`, since only these have slot
        # variables.
        self.slot_restorations.setdefault(
            slot_reference.original_variable_node_id, []).append(
                _SlotVariableRestoration(
                    optimizer_id=node_index,
                    slot_variable_id=slot_reference.slot_variable_node_id,
                    slot_name=slot_reference.slot_name))
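
# Illustration of the index built in __init__ above (node layout is
# hypothetical): if nodes[5] is an optimizer whose proto lists a slot for the
# variable at node 3, stored at node 7 under the name "momentum", then:
#
#   checkpoint = _Checkpoint(object_graph_proto, save_path)
#   # checkpoint.slot_restorations == {
#   #     3: [_SlotVariableRestoration(
#   #         optimizer_id=5, slot_variable_id=7, slot_name="momentum")],
#   # }
#   # _CheckpointPosition.bind_object pops these entries when node 3's
#   # variable is bound.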
336 """ 337 self._maybe_initialize_checkpointable() 338 if name in self._dependency_names: 339 raise ValueError( 340 ("A variable named '%s' already exists in this Checkpointable, but " 341 "Checkpointable._add_variable called to create another with " 342 "that name. Variable names must be unique within a Checkpointable " 343 "object.") % (name,)) 344 if context.in_eager_mode(): 345 # If this is a variable with a single Tensor stored in the checkpoint, we 346 # can set that value as an initializer rather than initializing and then 347 # assigning (when executing eagerly). This call returns None if there is 348 # nothing to restore. 349 checkpoint_initializer = self._preload_simple_restoration( 350 name=name, shape=shape) 351 else: 352 checkpoint_initializer = None 353 if (checkpoint_initializer is not None 354 and not ( 355 isinstance(initializer, CheckpointInitialValue) 356 and initializer.restore_uid > checkpoint_initializer.restore_uid)): 357 # If multiple Checkpointable objects are "creating" the same variable via 358 # the magic of custom getters, the one with the highest restore UID (the 359 # one called last) has to make the final initializer. If another custom 360 # getter interrupts this process by overwriting the initializer, then 361 # we'll catch that when we call _track_checkpointable. So this is "best 362 # effort" to set the initializer with the highest restore UID. 363 initializer = checkpoint_initializer 364 shape = None 365 366 new_variable = getter( 367 name=name, shape=shape, dtype=dtype, initializer=initializer, 368 **kwargs_for_getter) 369 370 # If we set an initializer and the variable processed it, tracking will not 371 # assign again. It will add this variable to our dependencies, and if there 372 # is a non-trivial restoration queued, it will handle that. This also 373 # handles slot variables. 374 return self._track_checkpointable(new_variable, name=name) 375 376 def _preload_simple_restoration(self, name, shape): 377 """Return a dependency's value for restore-on-create. 378 379 Note the restoration is not deleted; if for some reason preload is called 380 and then not assigned to the variable (for example because a custom getter 381 overrides the initializer), the assignment will still happen once the 382 variable is tracked (determined based on checkpoint.restore_uid). 383 384 Args: 385 name: The object-local name of the dependency holding the variable's 386 value. 387 shape: The shape of the variable being loaded into. 388 Returns: 389 An callable for use as a variable's initializer/initial_value, or None if 390 one should not be set (either because there was no variable with this name 391 in the checkpoint or because it needs more complex deserialization). Any 392 non-trivial deserialization will happen when the variable object is 393 tracked. 394 """ 395 deferred_dependencies_list = self._deferred_dependencies.get(name, ()) 396 if not deferred_dependencies_list: 397 # Nothing to do; we don't have a restore for this dependency queued up. 398 return 399 for checkpoint_position in deferred_dependencies_list: 400 if not checkpoint_position.is_simple_variable(): 401 # If _any_ pending restoration is too complicated to fit in an 402 # initializer (because it has dependencies, or because there are 403 # multiple Tensors to restore), bail and let the general tracking code 404 # handle it. 

  def _preload_simple_restoration(self, name, shape):
    """Return a dependency's value for restore-on-create.

    Note that the restoration is not deleted; if for some reason preload is
    called and the value is not assigned to the variable (for example because
    a custom getter overrides the initializer), the assignment will still
    happen once the variable is tracked (determined based on
    checkpoint.restore_uid).

    Args:
      name: The object-local name of the dependency holding the variable's
        value.
      shape: The shape of the variable being loaded into.

    Returns:
      A callable for use as a variable's initializer/initial_value, or None if
      one should not be set (either because there was no variable with this
      name in the checkpoint or because it needs more complex deserialization).
      Any non-trivial deserialization will happen when the variable object is
      tracked.
    """
    deferred_dependencies_list = self._deferred_dependencies.get(name, ())
    if not deferred_dependencies_list:
      # Nothing to do; we don't have a restore for this dependency queued up.
      return
    for checkpoint_position in deferred_dependencies_list:
      if not checkpoint_position.is_simple_variable():
        # If _any_ pending restoration is too complicated to fit in an
        # initializer (because it has dependencies, or because there are
        # multiple Tensors to restore), bail and let the general tracking code
        # handle it.
        return None
    checkpoint_position = max(
        deferred_dependencies_list,
        key=lambda restore: restore.checkpoint.restore_uid)
    return CheckpointInitialValue(
        checkpoint_position=checkpoint_position, shape=shape)

  def _track_checkpointable(self, checkpointable, name, overwrite=False):
    """Declare a dependency on another `Checkpointable` object.

    Indicates that checkpoints for this object should include variables from
    `checkpointable`.

    Variables in a checkpoint are mapped to `Checkpointable`s based on names
    if provided when the checkpoint was written, but otherwise use the order
    those `Checkpointable`s were declared as dependencies.

    To avoid breaking existing checkpoints when modifying a class, neither
    variable names nor dependency names (the names passed to
    `_track_checkpointable`) may change.

    Args:
      checkpointable: A `Checkpointable` which this object depends on.
      name: A local name for `checkpointable`, used for loading checkpoints
        into the correct objects.
      overwrite: Boolean, whether silently replacing dependencies is OK. Used
        for __setattr__, where throwing an error on attribute reassignment
        would be inappropriate.

    Returns:
      `checkpointable`, for convenience when declaring a dependency and
      assigning to a member variable in one statement.

    Raises:
      TypeError: If `checkpointable` does not inherit from `Checkpointable`.
      ValueError: If another object is already tracked by this name.
    """
    self._maybe_initialize_checkpointable()
    if not isinstance(checkpointable, CheckpointableBase):
      raise TypeError(
          ("Checkpointable._track_checkpointable() passed type %s, not a "
           "Checkpointable.") % (type(checkpointable),))
    new_reference = _CheckpointableReference(name=name, ref=checkpointable)
    if (name in self._dependency_names
        and self._dependency_names[name] is not checkpointable):
      if not overwrite:
        raise ValueError(
            ("Called Checkpointable._track_checkpointable() with name='%s', "
             "but a Checkpointable with this name is already declared as a "
             "dependency. Names must be unique (or overwrite=True).") % (name,))
      # This is a weird thing to do, but we're not going to stop people from
      # using __setattr__.
      for index, (old_name, _) in enumerate(self._checkpoint_dependencies):
        if name == old_name:
          self._checkpoint_dependencies[index] = new_reference
    else:
      self._checkpoint_dependencies.append(new_reference)

    self._dependency_names[name] = checkpointable
    deferred_dependency_list = self._deferred_dependencies.pop(name, None)
    if deferred_dependency_list is not None:
      for checkpoint_position in deferred_dependency_list:
        checkpoint_position.restore(checkpointable=checkpointable)
    return checkpointable
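
  # Usage sketch (the subclass here is hypothetical): because
  # _track_checkpointable returns its argument, a dependency can be declared
  # and assigned to a member variable in one statement.
  #
  #   class _Root(CheckpointableBase):
  #
  #     def __init__(self, child):
  #       self._child = self._track_checkpointable(child, name="child")
  #
  # Loading a checkpoint whose root node has a "child" edge will then restore
  # into self._child, even if the restore was requested before _Root existed.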

  def _restore_from_checkpoint_position(self, checkpoint_position):
    """Restore this object and its dependencies (may be deferred)."""
    # Attempt a breadth-first traversal, since presumably the user has more
    # control over shorter paths. If we don't have all of the dependencies at
    # this point, the end result is not breadth-first (since other deferred
    # traversals will happen later).
    visit_queue = collections.deque([checkpoint_position])
    restore_ops = []
    while visit_queue:
      current_position = visit_queue.popleft()
      restore_ops.extend(nest.flatten(
          current_position.checkpointable  # pylint: disable=protected-access
          ._single_restoration_from_checkpoint_position(
              checkpoint_position=current_position,
              visit_queue=visit_queue)))
    return restore_ops

  def _single_restoration_from_checkpoint_position(
      self, checkpoint_position, visit_queue):
    """Restore this object, and either queue its dependencies or defer them."""
    self._maybe_initialize_checkpointable()
    checkpoint = checkpoint_position.checkpoint
    # If the UID of this restore is lower than our current update UID, we
    # don't need to actually restore the object. However, we should pass the
    # restoration on to our dependencies.
    if checkpoint.restore_uid > self._update_uid:
      restore_op = self._scatter_tensors_from_checkpoint(
          checkpoint_position.restore_ops())
      self._update_uid = checkpoint.restore_uid
    else:
      restore_op = ()
    for child in checkpoint_position.object_proto.children:
      child_position = _CheckpointPosition(
          checkpoint=checkpoint,
          proto_id=child.node_id)
      local_object = self._dependency_names.get(child.local_name, None)
      if local_object is None:
        # We don't yet have a dependency registered with this name. Save it
        # in case we do.
        self._deferred_dependencies.setdefault(child.local_name, []).append(
            child_position)
      else:
        if child_position.bind_object(checkpointable=local_object):
          # This object's correspondence is new, so dependencies need to be
          # visited. Delay doing it so that we get a breadth-first dependency
          # resolution order (shallowest paths first). The caller is
          # responsible for emptying visit_queue.
          visit_queue.append(child_position)
    return restore_op

  def _scatter_tensors_from_checkpoint(self, attributes):
    """Restores this object from a checkpoint.

    Args:
      attributes: A dictionary of Tensors, with keys corresponding to those
        returned from _gather_tensors_for_checkpoint.

    Returns:
      A restore op to run (if graph building).
    """
    if attributes:
      raise AssertionError(
          ("A Checkpointable object which was not expecting any data received "
           "some from a checkpoint. (Got %s)") % (attributes,))
    return ()  # No restore ops

  def _gather_tensors_for_checkpoint(self):
    """Returns a dictionary of Tensors to save with this object."""
    return {}
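
# Sketch of the gather/scatter override pair for an object that saves one
# Tensor directly (the class, its attribute, and the use of
# resource_variable_ops here are hypothetical):
#
#   class _Counter(CheckpointableBase):
#
#     def __init__(self):
#       self._count = resource_variable_ops.ResourceVariable(0)
#
#     def _gather_tensors_for_checkpoint(self):
#       return {VARIABLE_VALUE_KEY: self._count}
#
#     def _scatter_tensors_from_checkpoint(self, attributes):
#       return self._count.assign(attributes[VARIABLE_VALUE_KEY])
#
# With VARIABLE_VALUE_KEY as the only key and no dependencies, the saved node
# satisfies is_simple_variable(), so the value may be restored on creation
# when executing eagerly (see the module-level comment on VARIABLE_VALUE_KEY).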


class Checkpointable(CheckpointableBase):
  """Manages dependencies on other objects.

  `Checkpointable` objects may have dependencies: other `Checkpointable`
  objects which should be saved if the object declaring the dependency is
  saved. A correctly saveable program has a dependency graph such that if
  changing a global variable affects an object (e.g. changes the behavior of
  any of its methods) then there is a chain of dependencies from the
  influenced object to the variable.

  Dependency edges have names, and are created implicitly when a
  `Checkpointable` object is assigned to an attribute of another
  `Checkpointable` object. For example:

  ```
  obj = Checkpointable()
  obj.v = ResourceVariable(0.)
  ```

  The `Checkpointable` object `obj` now has a dependency named "v" on a
  variable.

  `Checkpointable` objects may specify `Tensor`s to be saved and restored
  directly (e.g. a `Variable` indicating how to save itself) rather than
  through dependencies on other objects. See
  `Checkpointable._scatter_tensors_from_checkpoint` and
  `Checkpointable._gather_tensors_for_checkpoint` for details.
  """

  def __setattr__(self, name, value):
    """Support self.foo = checkpointable syntax."""
    # Perform the attribute assignment, and potentially call other __setattr__
    # overrides such as that for tf.keras.Model.
    super(Checkpointable, self).__setattr__(name, value)
    if isinstance(value, CheckpointableBase):
      self._track_checkpointable(
          value, name=name,
          # Allow the user to switch the Checkpointable which is tracked by
          # this name, since assigning a new variable to an attribute has
          # historically been fine (e.g. Adam did this).
          # TODO(allenl): Should this be a warning once Checkpointable
          # save/load is usable?
          overwrite=True)
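
# Usage sketch extending the class docstring's example (the use of
# resource_variable_ops is an assumption; it is not imported in this module):
# attribute assignment both stores the value and records a named dependency
# edge, and reassignment silently re-points the edge because __setattr__
# tracks with overwrite=True.
#
#   root = Checkpointable()
#   root.v = resource_variable_ops.ResourceVariable(1.)  # edge named "v"
#   root.v = resource_variable_ops.ResourceVariable(2.)  # "v" re-pointed
#   # root._checkpoint_dependencies now holds a single
#   # _CheckpointableReference(name="v", ref=<the second variable>).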