"""An object-local variable management scheme."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import collections
import weakref

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.trackable import constants
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export

# Module-level aliases for values defined in `constants`.
OBJECT_GRAPH_PROTO_KEY = constants.OBJECT_GRAPH_PROTO_KEY
VARIABLE_VALUE_KEY = constants.VARIABLE_VALUE_KEY
OBJECT_CONFIG_JSON_KEY = constants.OBJECT_CONFIG_JSON_KEY
SaveType = constants.SaveType


@tf_export("__internal__.tracking.TrackableReference", v1=[])
class TrackableReference(object):
  """A named reference to a trackable object for use with the `Trackable` class.

  These references mark named `Trackable` dependencies of a `Trackable` object
  and should be created when overriding `Trackable._checkpoint_dependencies`.

  A reference iterates as `name` then `ref`, so it can be unpacked like a
  `(name, ref)` pair, and it compares equal to such a tuple.

  Attributes:
    name: The local name for this dependency.
    ref: The `Trackable` object being referenced.
  """

  __slots__ = ("_name", "_ref")

  def __init__(self, name, ref):
    """Stores the dependency's local `name` and the referenced object `ref`."""
    self._name = name
    self._ref = ref

  @property
  def name(self):
    """The local name for this dependency."""
    return self._name

  @property
  def ref(self):
    """The object being referenced."""
    return self._ref

  def __iter__(self):
    """Yields `name`, then `ref`."""
    yield self.name
    yield self.ref

  def __repr__(self):
    return f"{self.__class__.__name__}(name={self.name}, ref={self.ref})"

  # NOTE: `__eq__` is defined without `__hash__`, so instances are unhashable.
  def __eq__(self, o):
    """Compares against a `(name, ref)` tuple or another `TrackableReference`.

    Args:
      o: The object to compare with.

    Returns:
      True if `o` is a tuple equal to `(self.name, self.ref)` or a
      `TrackableReference` with an equal name and ref; False for any other
      type.
    """
    if isinstance(o, tuple):
      return (self.name, self.ref) == o
    elif isinstance(o, TrackableReference):
      return self.name == o.name and self.ref == o.ref
    else:
      return False


class WeakTrackableReference(TrackableReference):
  """TrackableReference that stores weak references."""
  __slots__ = ()

  def __init__(self, name, ref):
    """Stores `name` and a weak reference to `ref`.

    Args:
      name: The local name for this dependency.
      ref: The object to reference weakly, or an existing `weakref.ref` to it.
    """
    # The check has to be made on `ref`, not on `self`: `self` is never a
    # `weakref.ref`, so testing it would always re-wrap, and wrapping an
    # existing weak reference again raises TypeError.
    if not isinstance(ref, weakref.ref):
      ref = weakref.ref(ref)
    super(WeakTrackableReference, self).__init__(name=name, ref=ref)

  @property
  def ref(self):
    """Calls the stored weak reference and returns the result."""
    return self._ref()


# TODO(bfontain): Update once sharded initialization interface is finalized.
ShardInfo = collections.namedtuple("CheckpointInitialValueShardInfo",
                                   ["shape", "offset"])


@tf_export("__internal__.tracking.CheckpointInitialValueCallable", v1=[])
class CheckpointInitialValueCallable(object):
  """A callable object that returns a CheckpointInitialValue.

  See CheckpointInitialValue for more information.
  """

  def __init__(self, checkpoint_position):
    """Remembers the checkpoint position the initial value is read from."""
    self._checkpoint_position = checkpoint_position

  @property
  def checkpoint_position(self):
    """The checkpoint position passed to the constructor."""
    return self._checkpoint_position

  def __call__(self, shape=None, dtype=None, shard_info=None):
    """Builds a `CheckpointInitialValue` for the stored checkpoint position.

    Args:
      shape: Forwarded to `CheckpointInitialValue`.
      dtype: Accepted for signature compatibility only; not used.
      shard_info: Forwarded to `CheckpointInitialValue`.

    Returns:
      A `CheckpointInitialValue`.
    """
    # Note that the signature here is for compatibility with normal callable
    # initializers which take shape and dtype. Although dtype isn't used, it
    # will get passed in by a functool.partial_wrapper in places like
    # base_layer_utils.py's make_variable.
    return CheckpointInitialValue(
        self._checkpoint_position, shape, shard_info=shard_info)

  @property
  def restore_uid(self):
    """The `restore_uid` of the stored checkpoint position."""
    return self._checkpoint_position.restore_uid


@tf_export("__internal__.tracking.CheckpointInitialValue", v1=[])
class CheckpointInitialValue(ops.Tensor):
  """Tensor wrapper for managing update UIDs in `Variables`.

  When supplied as an initial value, objects of this type let a `Variable`
  (`Variable`, `ResourceVariable`, etc.) know the UID of the restore the initial
  value came from. This allows deferred restorations to be sequenced in the
  order the user specified them, and lets us fall back on assignment if an
  initial value is not set (e.g. due to a custom getter interfering).

  See comments in _add_variable_with_custom_getter for more information about
  how `CheckpointInitialValue` is used.
  """

  def __init__(self, checkpoint_position, shape=None, shard_info=None):
    """Fetches the value tensor for `checkpoint_position` and wraps it.

    Args:
      checkpoint_position: Object whose `value_tensors` method is called to
        obtain the tensor stored under `VARIABLE_VALUE_KEY`.
      shape: The full shape. Only read (iterated over) when `shard_info` is
        given.
      shard_info: Optional object with `offset` and `shape` sequences (e.g. a
        `ShardInfo`) describing the slice to read.
    """
    if shard_info:
      # Spec layout: the full dimensions separated by spaces, a space, then
      # one "offset,size" pair per dimension joined by ":".
      full_shape_str = " ".join("%d" % d for d in shape) + " "
      slice_spec = ":".join(
          "%d,%d" % (o, s) for o, s in zip(shard_info.offset, shard_info.shape))
      shape_and_slice = full_shape_str + slice_spec
    else:
      shape_and_slice = ""
    self.wrapped_value = checkpoint_position.value_tensors(
        {VARIABLE_VALUE_KEY: shape_and_slice})[VARIABLE_VALUE_KEY]
    self._checkpoint_position = checkpoint_position

  def __getattr__(self, attr):
    """Forwards attribute lookups to the wrapped value.

    If the wrapped value lacks the attribute, the lookup is retried on this
    object through `__getattribute__`.
    """
    try:
      return getattr(self.wrapped_value, attr)
    except AttributeError:
      return self.__getattribute__(attr)

  @property
  def checkpoint_position(self):
    """The checkpoint position passed to the constructor."""
    return self._checkpoint_position


class NoRestoreSaveable(saveable_object.SaveableObject):
  """Embeds a tensor in a checkpoint with no restore ops."""

  def __init__(self, tensor, name, dtype=None, device=None):
    """Creates a saveable with a single `SaveSpec` for `tensor`.

    Args:
      tensor: The tensor to save.
      name: The name used for both the spec and the saveable.
      dtype: Forwarded to `SaveSpec`.
      device: Forwarded to `SaveSpec`.
    """
    spec = saveable_object.SaveSpec(
        tensor, "", name, dtype=dtype, device=device)
    super(NoRestoreSaveable, self).__init__(tensor, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Ignores the restored values and returns a no-op."""
    return control_flow_ops.no_op()


_SlotVariableRestoration = collections.namedtuple(
    "_SlotVariableRestoration",
    [
        # The checkpoint proto id of the optimizer object.
        "optimizer_id",
        # The checkpoint proto id of the slot variable.
        "slot_variable_id",
        "slot_name",
    ])


@tf_export("__internal__.tracking.no_automatic_dependency_tracking", v1=[])
def no_automatic_dependency_tracking(method):
  """Disables automatic dependency tracking on attribute assignment.

  Use to decorate any method of a Trackable object. Attribute assignment in
  that method will not add dependencies (also respected in Model). Harmless if
  used in a class which does not do automatic dependency tracking (which means
  it's safe to use in base classes which may have subclasses which also inherit
  from Trackable).

  Args:
    method: The method to decorate.

  Returns:
    A decorated method which sets and un-sets automatic dependency tracking for
    the object the method is called on (not thread safe).
  """

  def _method_wrapper(self, *args, **kwargs):
    # Remember the flag's current value (True if it was never set) so that it
    # is put back even when `method` raises.
    previous_value = getattr(self, "_self_setattr_tracking", True)
    self._self_setattr_tracking = False  # pylint: disable=protected-access
    try:
      result = method(self, *args, **kwargs)
    finally:
      self._self_setattr_tracking = previous_value  # pylint: disable=protected-access
    return result

  return tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)


@tf_contextlib.contextmanager
def no_manual_dependency_tracking_scope(obj):
  """A context that disables manual dependency tracking for the given `obj`.

  Sometimes library methods might track objects on their own and we might want
  to disable that and do the tracking on our own. One can then use this context
  manager to disable the tracking the library method does and do your own
  tracking.

  For example:

  class TestLayer(tf.keras.Layer):
    def build(self):
      with no_manual_dependency_tracking_scope(self):
        var = self.add_variable("name1")  # Creates a var and doesn't track it
      self._track_trackable(var, "name2")  # We track variable with name `name2`

  Args:
    obj: A trackable object.

  Yields:
    a scope in which the object doesn't track dependencies manually.
  """
  # pylint: disable=protected-access
  previous_value = getattr(obj, "_manual_tracking", True)
  obj._manual_tracking = False
  try:
    yield
  finally:
    # Put the previous setting back even if the body raised.
    obj._manual_tracking = previous_value


@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
  """A context that disables automatic dependency tracking when assigning attrs.

  Objects that inherit from Autotrackable automatically create dependencies
  to trackable objects through attribute assignments, and wrap data structures
  (lists or dicts) with trackable classes. This scope may be used to temporarily
  disable this behavior. This works similar to the decorator
  `no_automatic_dependency_tracking`.

  Example usage:
  ```
  model = tf.keras.Model()
  model.arr1 = []  # Creates a ListWrapper object
  with no_automatic_dependency_tracking_scope(model):
    model.arr2 = []  # Creates a regular, untracked python list
  ```

  Args:
    obj: A trackable object.

  Yields:
    a scope in which the object doesn't track dependencies.
  """
  previous_value = getattr(obj, "_setattr_tracking", True)
  obj._setattr_tracking = False  # pylint: disable=protected-access
  try:
    yield
  finally:
    # Put the previous setting back even if the body raised.
    obj._setattr_tracking = previous_value  # pylint: disable=protected-access


@tf_export("__internal__.tracking.Trackable", v1=[])
class Trackable(object):
  """Base class for `Trackable` objects without automatic dependencies.

  This class has no __setattr__ override for performance reasons. Dependencies
  must be added explicitly. Unless attribute assignment is performance-critical,
  use `AutoTrackable` instead. Use `Trackable` for `isinstance`
  checks.
  """

  # For compatibility with wrapt.ObjectProxy, attributes are all prefixed with
  # _self_. We have some properties to forward semi-public attributes to their
  # _self_ equivalents.

  @property
  def _setattr_tracking(self):
    """Whether attribute-assignment tracking is on; set to True on first read."""
    if not hasattr(self, "_self_setattr_tracking"):
      self._self_setattr_tracking = True
    return self._self_setattr_tracking

  @_setattr_tracking.setter
  def _setattr_tracking(self, value):
    self._self_setattr_tracking = value

  @property
  def _update_uid(self):
    """Forwards to `_self_update_uid`."""
    return self._self_update_uid

  @_update_uid.setter
  def _update_uid(self, value):
    self._self_update_uid = value

  @property
  def _unconditional_checkpoint_dependencies(self):
    """Forwards to `_self_unconditional_checkpoint_dependencies`."""
    return self._self_unconditional_checkpoint_dependencies

  @property
  def _unconditional_dependency_names(self):
    """Forwards to `_self_unconditional_dependency_names`."""
    return self._self_unconditional_dependency_names

  @property
  def _name_based_restores(self):
    """Forwards to `_self_name_based_restores`."""
    return self._self_name_based_restores

  # Trackable does not do automatic dependency tracking, but uses the
  # no_automatic_dependency_tracking decorator so it can avoid adding
  # dependencies if a subclass is Trackable / inherits from Model (both of
  # which have __setattr__ overrides).
326 @no_automatic_dependency_tracking 327 def _maybe_initialize_trackable(self): 328 """Initialize dependency management. 329 330 Not __init__, since most objects will forget to call it. 331 """ 332 if hasattr(self, "_self_unconditional_checkpoint_dependencies"): 333 # __init__ already called. This check means that we don't need 334 # Trackable.__init__() in the constructor of every TensorFlow object. 335 return 336 # A list of TrackableReference objects. Some classes implementing 337 # `Trackable`, notably `Optimizer`s, may override the 338 # _checkpoint_dependencies property with conditional dependencies 339 # (e.g. based on the current graph when saving). 340 self._self_unconditional_checkpoint_dependencies = [] 341 # Maps names -> Trackable objects 342 self._self_unconditional_dependency_names = {} 343 # Restorations for other Trackable objects on which this object may 344 # eventually depend. Maps local name -> CheckpointPosition list. Optimizers 345 # tack on conditional dependencies, and so need separate management of 346 # deferred dependencies too. 347 self._self_unconditional_deferred_dependencies = {} 348 # The UID of the highest assignment to this object. Used to ensure that the 349 # last requested assignment determines the final value of an object. 350 if hasattr(self, "_self_update_uid"): 351 raise AssertionError( 352 "Internal error: the object had an update UID set before its " 353 "initialization code was run.") 354 self._self_update_uid = -1 355 # When executing eagerly, holds a collection of _NameBasedRestoreCoordinator 356 # instances, which should be checked when creating variables or other 357 # saveables. These are passed on recursively to all dependencies, since 358 # unlike object-based checkpoint restores we don't know which subgraph is 359 # being restored in advance. This mechanism is only necessary for 360 # restore-on-create when executing eagerly, and so is unused when graph 361 # building. 
362 self._self_name_based_restores = set() 363 364 # Dictionary of SaveableObjects factories. This dictionary is defined when 365 # the object is loaded from the SavedModel. When writing a custom class, 366 # prefer overriding "_gather_saveables_from_checkpoint" to using this 367 # attribute. 368 self._self_saveable_object_factories = {} 369 370 @property 371 def _object_identifier(self): 372 """String used to identify this object in a SavedModel. 373 374 THIS FIELD HAS BEEN DEPRECATED IN FAVOR OF THE NAME REGISTERED WITH 375 `register_serializable`. 376 377 Generally, the object identifier is constant across objects of the same 378 class, while the metadata field is used for instance-specific data. 379 380 Returns: 381 String object identifier. 382 """ 383 return "_generic_user_object" 384 385 def _no_dependency(self, value): 386 """If automatic dependency tracking is enabled, ignores `value`.""" 387 return value 388 389 def _name_based_attribute_restore(self, checkpoint): 390 """Restore the object's attributes from a name-based checkpoint.""" 391 self._self_name_based_restores.add(checkpoint) 392 if self._self_update_uid < checkpoint.restore_uid: 393 checkpoint.eager_restore(self) 394 self._self_update_uid = checkpoint.restore_uid 395 396 @property 397 def _checkpoint_dependencies(self): 398 """All dependencies of this object. 399 400 May be overridden to include conditional dependencies. 401 402 Returns: 403 A list of `TrackableReference` objects indicating named 404 `Trackable` dependencies which should be saved along with this 405 object. 406 """ 407 return self._self_unconditional_checkpoint_dependencies 408 409 @property 410 def _deferred_dependencies(self): 411 """A dictionary with deferred dependencies. 412 413 Stores restorations for other Trackable objects on which this object 414 may eventually depend. May be overridden by sub-classes (e.g. 
Optimizers use 415 conditional dependencies based the current graph, and so need separate 416 management of deferred dependencies too). 417 418 Returns: 419 A dictionary mapping from local name to a list of CheckpointPosition 420 objects. 421 """ 422 return self._self_unconditional_deferred_dependencies 423 424 def _lookup_dependency(self, name): 425 """Look up a dependency by name. 426 427 May be overridden to include conditional dependencies. 428 429 Args: 430 name: The local name of the dependency. 431 432 Returns: 433 A `Trackable` object, or `None` if no dependency by this name was 434 found. 435 """ 436 return self._self_unconditional_dependency_names.get(name, None) 437 438 def _add_variable_with_custom_getter(self, 439 name, 440 shape=None, 441 dtype=dtypes.float32, 442 initializer=None, 443 getter=None, 444 overwrite=False, 445 **kwargs_for_getter): 446 """Restore-on-create for a variable be saved with this `Trackable`. 447 448 If the user has requested that this object or another `Trackable` which 449 depends on this object be restored from a checkpoint (deferred loading 450 before variable object creation), `initializer` may be ignored and the value 451 from the checkpoint used instead. 452 453 Args: 454 name: A name for the variable. Must be unique within this object. 455 shape: The shape of the variable. 456 dtype: The data type of the variable. 457 initializer: The initializer to use. Ignored if there is a deferred 458 restoration stored in the Trackable. 459 getter: The getter to wrap which actually fetches the variable. 460 overwrite: If True, disables unique name and type checks. 461 **kwargs_for_getter: Passed to the getter. 462 463 Returns: 464 The new variable object. 465 466 Raises: 467 ValueError: If the variable name is not unique. 
468 """ 469 self._maybe_initialize_trackable() 470 with ops.init_scope(): 471 if context.executing_eagerly(): 472 # If this is a variable with a single Tensor stored in the checkpoint, 473 # we can set that value as an initializer rather than initializing and 474 # then assigning (when executing eagerly). This call returns None if 475 # there is nothing to restore. 476 checkpoint_initializer = self._preload_simple_restoration(name=name) 477 else: 478 checkpoint_initializer = None 479 if (checkpoint_initializer is not None and 480 not (isinstance(initializer, CheckpointInitialValueCallable) and 481 (initializer.restore_uid > checkpoint_initializer.restore_uid))): 482 # If multiple Trackable objects are "creating" the same variable 483 # via the magic of custom getters, the one with the highest restore UID 484 # (the one called last) has to make the final initializer. If another 485 # custom getter interrupts this process by overwriting the initializer, 486 # then we'll catch that when we call _track_trackable. So this is 487 # "best effort" to set the initializer with the highest restore UID. 488 initializer = checkpoint_initializer 489 new_variable = getter( 490 name=name, 491 shape=shape, 492 dtype=dtype, 493 initializer=initializer, 494 **kwargs_for_getter) 495 496 # If we set an initializer and the variable processed it, tracking will not 497 # assign again. It will add this variable to our dependencies, and if there 498 # is a non-trivial restoration queued, it will handle that. This also 499 # handles slot variables. 500 if not overwrite or isinstance(new_variable, Trackable): 501 return self._track_trackable(new_variable, name=name, overwrite=overwrite) 502 else: 503 # TODO(allenl): Some variable types are not yet supported. Remove this 504 # fallback once all get_variable() return types are Trackable. 505 return new_variable 506 507 def _preload_simple_restoration(self, name): 508 """Return a dependency's value for restore-on-create. 
509 510 Note the restoration is not deleted; if for some reason preload is called 511 and then not assigned to the variable (for example because a custom getter 512 overrides the initializer), the assignment will still happen once the 513 variable is tracked (determined based on checkpoint.restore_uid). 514 515 Args: 516 name: The object-local name of the dependency holding the variable's 517 value. 518 519 Returns: 520 An callable for use as a variable's initializer/initial_value, or None if 521 one should not be set (either because there was no variable with this name 522 in the checkpoint or because it needs more complex deserialization). Any 523 non-trivial deserialization will happen when the variable object is 524 tracked. 525 """ 526 deferred_dependencies_list = self._deferred_dependencies.get(name, ()) 527 if not deferred_dependencies_list: 528 # Nothing to do; we don't have a restore for this dependency queued up. 529 return 530 for checkpoint_position in deferred_dependencies_list: 531 if not checkpoint_position.is_simple_variable(): 532 # If _any_ pending restoration is too complicated to fit in an 533 # initializer (because it has dependencies, or because there are 534 # multiple Tensors to restore), bail and let the general tracking code 535 # handle it. 536 return None 537 checkpoint_position = max( 538 deferred_dependencies_list, 539 key=lambda restore: restore.checkpoint.restore_uid) 540 return CheckpointInitialValueCallable( 541 checkpoint_position=checkpoint_position) 542 543 def _track_trackable(self, trackable, name, overwrite=False): 544 """Declare a dependency on another `Trackable` object. 545 546 Indicates that checkpoints for this object should include variables from 547 `trackable`. 548 549 Variables in a checkpoint are mapped to `Trackable`s based on the names 550 provided when the checkpoint was written. 
To avoid breaking existing 551 checkpoints when modifying a class, neither variable names nor dependency 552 names (the names passed to `_track_trackable`) may change. 553 554 Args: 555 trackable: A `Trackable` which this object depends on. 556 name: A local name for `trackable`, used for loading checkpoints into the 557 correct objects. 558 overwrite: Boolean, whether silently replacing dependencies is OK. Used 559 for __setattr__, where throwing an error on attribute reassignment would 560 be inappropriate. 561 562 Returns: 563 `trackable`, for convenience when declaring a dependency and 564 assigning to a member variable in one statement. 565 566 Raises: 567 TypeError: If `trackable` does not inherit from `Trackable`. 568 ValueError: If another object is already tracked by this name. 569 """ 570 self._maybe_initialize_trackable() 571 if not isinstance(trackable, Trackable): 572 raise TypeError( 573 "Trackable._track_trackable() can only be used to track objects of " 574 f"type Trackable. Got type {type(trackable)}.") 575 if not getattr(self, "_manual_tracking", True): 576 return trackable 577 new_reference = TrackableReference(name=name, ref=trackable) 578 current_object = self._lookup_dependency(name) 579 if (current_object is not None and current_object is not trackable): 580 if not overwrite: 581 raise ValueError( 582 f"Called Trackable._track_trackable() with name='{name}', " 583 "but a Trackable with this name is already declared as a " 584 "dependency. Names must be unique (or overwrite=True).") 585 # This is a weird thing to do, but we're not going to stop people from 586 # using __setattr__. 
587 for index, (old_name, _) in enumerate( 588 self._self_unconditional_checkpoint_dependencies): 589 if name == old_name: 590 self._self_unconditional_checkpoint_dependencies[ 591 index] = new_reference 592 elif current_object is None: 593 self._self_unconditional_checkpoint_dependencies.append(new_reference) 594 self._handle_deferred_dependencies(name=name, trackable=trackable) 595 self._self_unconditional_dependency_names[name] = trackable 596 return trackable 597 598 def _handle_deferred_dependencies(self, name, trackable): 599 """Pop and load any deferred checkpoint restores into `trackable`. 600 601 This method does not add a new dependency on `trackable`, but it does 602 check if any outstanding/deferred dependencies have been queued waiting for 603 this dependency to be added (matched based on `name`). If so, 604 `trackable` and its dependencies are restored. The restorations are 605 considered fulfilled and so are deleted. 606 607 `_track_trackable` is more appropriate for adding a 608 normal/unconditional dependency, and includes handling for deferred 609 restorations. This method allows objects such as `Optimizer` to use the same 610 restoration logic while managing conditional dependencies themselves, by 611 overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the 612 object's dependencies based on the context it is saved/restored in (a single 613 optimizer instance can have state associated with multiple graphs). 614 615 Args: 616 name: The name of the dependency within this object (`self`), used to 617 match `trackable` with values saved in a checkpoint. 618 trackable: The Trackable object to restore (inheriting from `Trackable`). 
619 """ 620 self._maybe_initialize_trackable() 621 trackable._maybe_initialize_trackable() # pylint: disable=protected-access 622 deferred_dependencies_list = self._deferred_dependencies.pop(name, ()) 623 for checkpoint_position in sorted( 624 deferred_dependencies_list, 625 key=lambda restore: restore.checkpoint.restore_uid, 626 reverse=True): 627 checkpoint_position.restore(trackable) 628 629 # Pass on any name-based restores queued in this object. 630 for name_based_restore in sorted( 631 self._self_name_based_restores, 632 key=lambda checkpoint: checkpoint.restore_uid, 633 reverse=True): 634 trackable._name_based_attribute_restore(name_based_restore) # pylint: disable=protected-access 635 636 def _gather_saveables_for_checkpoint(self): 637 """Returns a dictionary of values to checkpoint with this object. 638 639 NOTE: This method is deprecated, please use `_serialize_to_tensors` and 640 `_restore_from_tensors` instead. 641 642 Keys in the returned dictionary are local to this object and in a separate 643 namespace from dependencies. Values may either be `SaveableObject` factories 644 or variables easily converted to `SaveableObject`s (as in 645 `tf.compat.v1.train.Saver`'s 646 `var_list` constructor argument). 647 648 `SaveableObjects` have a name set, which Trackable needs to generate 649 itself. So rather than returning `SaveableObjects` directly, this method 650 should return a dictionary of callables which take `name` arguments and 651 return `SaveableObjects` with that name. 652 653 If this object may also be passed to the global-name-based 654 `tf.compat.v1.train.Saver`, 655 the returned callables should have a default value for their name argument 656 (i.e. be callable with no arguments). 657 658 Returned values must be saved only by this object; if any value may be 659 shared, it should instead be a dependency. 
For example, variable objects 660 save their own values with the key `VARIABLE_VALUE_KEY`, but objects which 661 reference variables simply add a dependency. 662 663 Returns: 664 The dictionary mapping attribute names to `SaveableObject` factories 665 described above. For example: 666 {VARIABLE_VALUE_KEY: 667 lambda name="global_name_for_this_object": 668 SaveableObject(name=name, ...)} 669 """ 670 # TODO(kathywu): In order to remove this circular dependency, remove all 671 # external calls to _gather_saveables_for_checkpoint. 672 # pylint: disable=g-import-not-at-top 673 from tensorflow.python.training.saving import saveable_object_util 674 # pylint: enable=g-import-not-at-top 675 if saveable_object_util.trackable_has_serialize_to_tensor(self): 676 677 def create_saveable(name=""): 678 return saveable_object_util.TrackableSaveable(self, name) 679 680 return {"": create_saveable} 681 else: 682 return getattr(self, "_self_saveable_object_factories", {}) 683 684 def _serialize_to_tensors(self): 685 """Gathers tensors to save to the checkpoint. 686 687 You should only override `_serialize_to_tensors` and `_restore_from_tensors` 688 if you are defining a custom resource or variable with custom ops. 689 690 Otherwise, please store the state of your trackable in `tf.Variable` objects 691 and add them to Trackable object hierarchy using `setattr` (for subclasses 692 of `AutoTrackable`) or overriding the `_trackable_children` method. 693 694 For an example of a valid implementation of these two methods, please see 695 `DenseHashTable`. 
696 697 **Invalid implementation** 698 699 ```` 700 class NamedTrackable(Trackable): 701 def __init__(self, name: str): 702 self.name = name 703 def _serialize_to_tensors(self): 704 return {"name": self.name} 705 def _restore_from_tensors(self, restored_tensors): 706 self.name = restored_tensors["name"] 707 ``` 708 709 In this example, `NamedTrackable` can be saved and restored from 710 checkpoints, but is incompatible with SavedModel, which tries to convert 711 the serialize/restore functions into tf.functions. This fails because 712 attribute assignment (`self.attr = new_value`) is not graph-friendly. 713 714 **Suggested fix** 715 716 ``` 717 class NamedTrackable(Trackable): 718 def __init__(self, name: str): 719 self.name = tf.Variable(name) 720 721 def _trackable_children(self): 722 return {"name": self.name} 723 ``` 724 725 If the `name` attribute should be saved to the checkpoint, then convert it 726 a `tf.Variable`. 727 728 Returns: 729 A dictionary mapping names to tensors. 730 """ 731 raise NotImplementedError 732 733 def _restore_from_tensors(self, restored_tensors): 734 """Restores checkpointed values to this `Trackable`. 735 736 Please see the documentation for `Trackable._serialize_to_tensors`. 737 738 Args: 739 restored_tensors: A dictionary mapping names to tensors. The keys to this 740 dictionary matches the names passed to _serialize_to_tensors. 741 742 Returns: 743 An op that runs the restoration. 744 """ 745 raise NotImplementedError 746 747 def _map_resources(self, save_options): # pylint: disable=unused-argument 748 """Makes new resource handle ops corresponding to existing resource tensors. 749 750 Internal sub-classes can override this to inform model saving how to add new 751 resource handle ops to the main GraphDef of a SavedModel (TF 1.x style 752 graph), which allows session based APIs (e.g, C++ loader API) to interact 753 with resources owned by this object. 754 755 Args: 756 save_options: A tf.saved_model.SaveOptions instance. 
757 758 Returns: 759 A tuple of (object_map, resource_map): 760 object_map: A dictionary mapping from objects that hold existing 761 resource tensors to replacement objects created to hold the new 762 resource tensors. 763 resource_map: A dictionary mapping from existing resource tensors to 764 newly created resource tensors. 765 """ 766 return {}, {} 767 768 def _serialize_to_proto(self, object_proto=None, **kwargs): 769 """Returns a proto of any type to be saved into the SavedModel. 770 771 Trackable classes decorated with `register_serializable` should overwrite 772 this method to save metadata for this object to the SavedModel. The proto 773 returned by this function will be passed to `_deserialize_from_proto` in the 774 form of a `google.protobuf.Any` proto. 775 776 This data is only saved and used by the Python API. Existing C++ loading 777 APIs such as `tensorflow::LoadSavedModel` will not read this field at all. 778 779 Args: 780 object_proto: A `SavedObject` proto that may be filled by this function. 781 Only the core serializable types (Variable, Function, Constant, Asset) 782 should modify this argument. 783 **kwargs: Future keyword arguments passed to the object during saving. 784 785 Returns: 786 A proto that serializes this class's type. 787 """ 788 del object_proto, kwargs # Unused. 789 790 return None 791 792 @classmethod 793 def _deserialize_from_proto(cls, 794 proto=None, 795 dependencies=None, 796 object_proto=None, 797 export_dir=None, 798 asset_file_def=None, 799 operation_attributes=None, 800 **kwargs): 801 """Returns a new object restored by the SavedModel. 802 803 Trackable classes decorated with `register_serializable` should overwrite 804 this method to change how the object is loaded from SavedModel. By default, 805 the object is initialized with no arguments. 
806 807 Example: 808 809 ``` 810 def _serialize_to_proto(self, **unused_kwargs): 811 return Message(name="a") 812 813 @classmethod 814 def _deserialize_from_proto(cls, proto, **unused_kwargs): 815 if proto.Is(Message.DESCRIPTOR): 816 unpacked = Message() 817 proto.Unpack(unpacked) 818 return cls(unpacked.name) 819 else: 820 return cls() 821 ``` 822 823 This function is only used by the Python API. C++ and TensorFlow Serving do 824 not have access to your registered class and cannot execute any of the 825 non-tf.functions attached to the Python class. However, all signatures and 826 tf.functions are still accessible. 827 828 **Avoid creating duplicate trackables** 829 830 SavedModel is saved by recursively gathering all of the trackables and their 831 children. SavedModel loading reverses those steps by creating all 832 trackables, then reconnecting the children trackables to their parents using 833 `Trackable._add_trackable_child`. 834 835 That means that if `_deserialize_from_proto` calls the `__init__` function, 836 which creates all of the children trackables, then those children end up 837 being created *twice*. 838 839 To avoid this, structure your code so that Trackables are not created 840 when deserialized from SavedModel: 841 842 ``` 843 @register_serializable() 844 class Serializable(trackable): 845 def __init __(self, from_proto=False): 846 create_non_trackable_objects() 847 if not from_proto: 848 create_variables_and_other_trackables() 849 850 def _deserialize_from_proto(cls, **kwargs): 851 return cls(from_proto=True) 852 853 def _add_trackable_child(self, name, value): 854 self.__setattr__(name, value) 855 ``` 856 857 Args: 858 proto: A `google.protobuf.Any` proto read from the `SavedModel`. 859 dependencies: A dictionary mapping names to dependencies (see 860 `_deserialization_dependencies`) 861 object_proto: The `SavedObject` proto for this object. 
862 export_dir: The `SavedModel` directory 863 asset_file_def: The `MetaGraphDef`'s `asset_file_def` field. 864 operation_attributes: Dictionary mapping nodes to attribute from the 865 imported `GraphDef`. 866 **kwargs: Future keyword arguments passed to the object when loading. 867 868 Returns: 869 A new object. 870 """ 871 del (proto, dependencies, object_proto, export_dir, asset_file_def, 872 operation_attributes, kwargs) 873 874 return cls() 875 876 def _add_trackable_child(self, name, value): 877 """Restores a connection between trackables when loading from SavedModel. 878 879 SavedModel stores both the object metadata and its list of children. When 880 loading, this function is used along with `_deserialize_from_proto` to load 881 objects from the SavedModel: First, all saved objects are created with 882 `_deserialize_from_proto`. After that is complete, the children are 883 connected using `_add_trackable_child`. 884 885 **Example** 886 887 `tf.Module`, `tf.keras.Model` and Keras layers use `__setattr__` to track 888 children. This is why users can call `model.v = tf.Variable(...)`, and the 889 variable will be automatically saved to the checkpoint. The implementation 890 of this method for the listed objects is: 891 892 ``` 893 def _add_trackable_child(self, name, value): 894 self.__setattr__(name, value) 895 ``` 896 897 Args: 898 name: The name of the connection between the parent and child `Trackable`. 899 value: The child `Trackable` object. 900 """ 901 self._track_trackable(value, name, overwrite=True) 902 903 def _deserialization_dependencies(self, children): 904 """Returns a dictionary containing `Trackables` that this object depends on. 905 906 Dependencies define the order to serialize and deserialize objects in the 907 SavedModel. 
For example: 908 909 class A(Trackable): 910 b = B() 911 def _deserialization_dependencies(self, children): 912 return {'b': self.b} 913 914 class B(Trackable): 915 pass 916 917 We say that object `a=A()` depends on `a.b`. 918 919 Dependencies are guaranteed to be serialized and deserialized before the 920 object depending on them. The following methods use dependencies: 921 - `_deserialize_from_proto` [loading] 922 923 SavedModel loads with the bottom-up approach, by first creating all objects 924 in the order defined by the dependencies, then connecting the children. 925 926 Unlike `_trackable_children`, this function does not define the 927 `SavedObjectGraph`. It only changes the order in which things are 928 saved/loaded. Therefore, if there are dependencies that are not in the 929 `SavedObjectGraph`, saving will fail. 930 931 Args: 932 children: Dict returned from `_trackable_children`. 933 934 Returns: 935 A dictionary mapping names to `Trackable`. 936 """ 937 del children # Unused. 938 return {} 939 940 def _trackable_children(self, 941 save_type=SaveType.CHECKPOINT, 942 cache=None, 943 **kwargs): 944 """Returns this object's `Trackable` attributes. 945 946 This method is used to build the object graph (or the object hierarchy, 947 in pickling terms) for checkpoint save/restore, and `SavedModel` export. 948 949 Override this method to define the children of this instance. Please read 950 the implementation restrictions: 951 952 **Rule 1: All children must be convertable to `Trackable`.** 953 954 Must pass `isinstance` check or `converter.convert_to_trackable`. 955 956 **Rule 2: [Checkpoint-only] Do not create new objects.** 957 958 When saving to a `SavedModel`, this method is called *exactly once* for each 959 `Trackable` in the object graph. When saving or restoring from a checkpoint, 960 this method may be called *multiple times*. 
Thus, this method may create 961 new Trackables when `save_type == SaveType.SAVEDMODEL` but not when 962 `save_type == SaveType.CHECKPOINT`. 963 964 When saving to `SavedModel`, new `Trackable` children can be created to save 965 non-Trackable attributes to the `SavedModel`. In the example below, `hyper` 966 is a regular python float hyperparameter. To save this value, a new Variable 967 is created to store the value of `hyper`: 968 969 ``` 970 def __init__(self): 971 self.hyper = 1e-5 972 973 def _trackable_children(self, save_type, **unused_kwargs): 974 # Correct implementation 975 children = {} 976 if format == 'saved_model': 977 children['hyper'] = tf.Variable(self.hyper) 978 return children 979 ``` 980 981 An incorrect implementation of `_trackable_children` is shown below. This 982 function would cause failures when loading the checkpoint, and calling 983 `load_status.assert_consumed()` or 984 `load_status.assert_existing_objects_matched`. If you want a value to be 985 saved in the checkpoint, hyper must be defined as a `tf.Variable` from the 986 start. 987 988 ``` 989 def _trackable_children(self, save_type, **unused_kwargs): 990 # Incorrect implementation 991 return {'hyper': tf.Variable(self.hyper)} 992 ``` 993 994 **Rule 3: [`SavedModel`-only] Watch out for un-traced tf.functions.** 995 996 At the begining of `_trackable_children`, always call 997 `get_concrete_function()` for any `tf.function` that has an input signature. 998 999 When `tf.functions` are saved to `SavedModel`, any `tf.functions` that have 1000 an input signature and has never been called is traced at export time in 1001 order to copy the op graph into the `SavedModel`. `tf.functions` that are 1002 traced for the first time are allowed to create new state: 1003 1004 1005 ``` 1006 @tf.function(input_signature=[]): 1007 def fn(self); 1008 if self.v is None: 1009 self.v = tf.Variable(1.) 
1010 return self.v 1011 ``` 1012 1013 A problem occurs when there is a `Trackable` that returns `fn` as one of its 1014 children and `self.v` has not been created yet. When `fn` is traced, 1015 `self.v` is added to the `Trackable`, but `SavedModel` does not see this 1016 modification since the `Trackable`'s children have already been gathered. 1017 1018 Therefore, as a precaution, call `get_concrete_function()` at the very 1019 start of `_trackable_children` to ensure that the function is traced: 1020 1021 1022 ``` 1023 def _trackable_children(self): 1024 self.fn.get_concrete_function() 1025 return {"v": self.v, "fn": self.fn} 1026 ``` 1027 1028 Args: 1029 save_type: A string, can be 'savedmodel' or 'checkpoint'. Defaults to 1030 SaveType.CHECKPOINT. 1031 cache: May be `None`, or a dictionary. When `save_type == savedmodel`, a 1032 new cache is created at the start of the SavedModel export, and shared 1033 between all `Trackables` in the same object graph. This cache may be 1034 used for advanced saving functionality. 1035 **kwargs: Additional kwargs that may be added at a later time. 1036 1037 Returns: 1038 Dictionary mapping names to child trackables. 1039 """ 1040 del save_type, cache, kwargs # Unused. 1041 1042 self._maybe_initialize_trackable() 1043 return {name: ref for name, ref in self._checkpoint_dependencies} 1044 1045 def _export_to_saved_model_graph(self, 1046 object_map=None, 1047 tensor_map=None, 1048 options=None, 1049 **kwargs): 1050 """Creates a copy of this object's tensors onto SavedModel graph. 1051 1052 Needs to be overridden if the class contains tensors that must be saved 1053 into the graph. This method should update the `object_map` and `tensor_map` 1054 dictionaries. 1055 1056 This method is called on all nodes in the Trackable Graph (generated by 1057 `_trackable_children`). The nodes are traversed in the order defined by 1058 `_deserialization_dependencies` 1059 1060 All usages of _map_resources should be migrated to this method. 
1061 1062 Args: 1063 object_map: A dictionary that maps original Trackables to the copied 1064 Trackables. This only needs to be updated if the object is a 1065 tf.function, or if the copied tensors are necessary for checkpointing 1066 this object. 1067 tensor_map: Dictionary mapping original tensors to copied tensors. 1068 options: A `tf.saved_model.SaveOptions` object. 1069 **kwargs: Additional kwargs that may be added at a later time. 1070 1071 Returns: 1072 Flat list of original tensors that have been copied. 1073 """ 1074 del kwargs # Unused. 1075 self_object_map, self_tensor_map = self._map_resources(options) 1076 object_map.update(self_object_map) 1077 tensor_map.update(self_tensor_map) 1078 return list(self_tensor_map.keys()) 1079