# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A mixin class that delegates another Trackable to be used when saving.

This is intended to be used with wrapper classes that cannot directly proxy the
wrapped object (e.g. with wrapt.ObjectProxy), because there are inner attributes
that cannot be exposed.

The Wrapper class itself cannot contain any Trackable children, as only the
delegated Trackable will be saved to checkpoint and SavedModel.

This class will "disappear" and be replaced with the wrapped inner Trackable
after a cycle of SavedModel saving and loading, unless the object is registered
and loaded with Keras.
"""

from tensorflow.python.util.tf_export import tf_export


def _delegated_property(attr_name, settable=False):
  """Returns a property that forwards `attr_name` to `self._trackable`.

  Args:
    attr_name: Name of the attribute looked up on the wrapped trackable.
    settable: When True, assignments are forwarded to the wrapped trackable as
      well. When False the property has no setter, so assigning to it raises
      `AttributeError`.

  Returns:
    A `property` object meant to be bound as a class attribute.
  """

  def _get(self):
    return getattr(self._trackable, attr_name)

  if not settable:
    return property(_get)

  def _set(self, value):
    setattr(self._trackable, attr_name, value)

  return property(_get, _set)


def _delegated_method(method_name):
  """Returns a method that calls `method_name` on `self._trackable`.

  All positional and keyword arguments are passed through unchanged, and the
  wrapped trackable's return value is returned as-is.

  Args:
    method_name: Name of the method invoked on the wrapped trackable.

  Returns:
    A plain function meant to be bound as a class attribute (and therefore
    used as an instance method).
  """

  def _forward(self, *args, **kwargs):
    return getattr(self._trackable, method_name)(*args, **kwargs)

  # Give the generated function a readable identity in tracebacks/reprs.
  _forward.__name__ = method_name
  _forward.__qualname__ = "DelegatingTrackableMixin." + method_name
  return _forward


@tf_export("__internal__.tracking.DelegatingTrackableMixin", v1=[])
class DelegatingTrackableMixin(object):
  """Mixin that forwards the Trackable hooks below to a wrapped trackable.

  DO NOT USE THIS UNLESS YOU ARE THE KERAS LOSS SCALE OPTIMIZER.

  Use this class via multiple inheritance together with a Trackable subclass.
  Every Trackable attribute and method defined here is answered by the
  trackable object given to the constructor instead of by the subclass itself
  (the mixin's definitions must therefore precede the Trackable base's in the
  MRO to take effect).

  From a Checkpoint's point of view, the subclass thus looks like the wrapped
  trackable. LossScaleOptimizer relies on this so that its checkpoint format is
  the same as a normal optimizer's: a model saved with a normal Optimizer can be
  restored into a LossScaleOptimizer, and vice versa. The only difference in
  checkpoint format is that a LossScaleOptimizer also saves the loss scale.
  """

  def __init__(self, trackable_obj):
    # Only records the delegate; super().__init__() is intentionally not
    # called. NOTE(review): presumably so the co-base's constructor is not run
    # implicitly -- confirm before adding a super() call here.
    self._trackable = trackable_obj

  # pylint: disable=protected-access

  # Attributes forwarded for both reading and assignment.
  _setattr_tracking = _delegated_property("_setattr_tracking", settable=True)
  _update_uid = _delegated_property("_update_uid", settable=True)

  # Attributes forwarded read-only (no setter defined).
  _unconditional_checkpoint_dependencies = _delegated_property(
      "_unconditional_checkpoint_dependencies")
  _unconditional_dependency_names = _delegated_property(
      "_unconditional_dependency_names")
  _name_based_restores = _delegated_property("_name_based_restores")
  _object_identifier = _delegated_property("_object_identifier")
  _tracking_metadata = _delegated_property("_tracking_metadata")
  _checkpoint_dependencies = _delegated_property("_checkpoint_dependencies")
  _deferred_dependencies = _delegated_property("_deferred_dependencies")

  # Methods whose exact signatures are kept explicit.
  def _maybe_initialize_trackable(self):
    delegate = self._trackable
    return delegate._maybe_initialize_trackable()

  def _handle_deferred_dependencies(self, name, trackable):
    delegate = self._trackable
    return delegate._handle_deferred_dependencies(name, trackable)

  # Methods that pass arbitrary arguments straight through.
  _no_dependency = _delegated_method("_no_dependency")
  _name_based_attribute_restore = _delegated_method(
      "_name_based_attribute_restore")
  _lookup_dependency = _delegated_method("_lookup_dependency")
  _add_variable_with_custom_getter = _delegated_method(
      "_add_variable_with_custom_getter")
  _preload_simple_restoration = _delegated_method(
      "_preload_simple_restoration")
  _track_trackable = _delegated_method("_track_trackable")
  _gather_saveables_for_checkpoint = _delegated_method(
      "_gather_saveables_for_checkpoint")
  _trackable_children = _delegated_method("_trackable_children")
  _deserialization_dependencies = _delegated_method(
      "_deserialization_dependencies")
  _export_to_saved_model_graph = _delegated_method(
      "_export_to_saved_model_graph")
  # pylint: enable=protected-access