• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A mixin class that delegates another Trackable to be used when saving.
16
17This is intended to be used with wrapper classes that cannot directly proxy the
18wrapped object (e.g. with wrapt.ObjectProxy), because there are inner attributes
19that cannot be exposed.
20
21The Wrapper class itself cannot contain any Trackable children, as only the
22delegated Trackable will be saved to checkpoint and SavedModel.
23
24This class will "disappear" and be replaced with the wrapped inner Trackable
25after a cycle of SavedModel saving and loading, unless the object is registered
26and loaded with Keras.
27"""
28
29from tensorflow.python.util.tf_export import tf_export
30
31
@tf_export("__internal__.tracking.DelegatingTrackableMixin", v1=[])
class DelegatingTrackableMixin:
  """Mixin that forwards every Trackable hook to a wrapped trackable object.

  DO NOT USE THIS UNLESS YOU ARE THE KERAS LOSS SCALE OPTIMIZER.

  The mixin is meant for multiple inheritance: a class that already derives
  from Trackable additionally derives from this mixin, and from then on each
  Trackable attribute and method defined below is answered by the trackable
  object handed to the constructor instead of by the subclass itself.

  From a Checkpoint's point of view the subclass therefore looks like the
  wrapped trackable. LossScaleOptimizer relies on this so that its checkpoint
  format matches that of a plain optimizer, which lets a model be saved with a
  normal Optimizer and restored with a LossScaleOptimizer, or the other way
  around. The only difference in checkpoint format is that a
  LossScaleOptimizer additionally saves the loss scale.
  """

  def __init__(self, trackable_obj):
    """Remembers the trackable that all delegated calls are forwarded to.

    Args:
      trackable_obj: The trackable object whose Trackable attributes and
        methods this mixin exposes as its own.
    """
    self._trackable = trackable_obj

  # pylint: disable=protected-access

  # ---------------------------------------------------------------------------
  # Readable and writable attributes: both the getter and the setter operate
  # on the wrapped trackable, never on `self`.
  # ---------------------------------------------------------------------------

  @property
  def _setattr_tracking(self):
    """The wrapped trackable's `_setattr_tracking` value."""
    delegate = self._trackable
    return delegate._setattr_tracking

  @_setattr_tracking.setter
  def _setattr_tracking(self, value):
    delegate = self._trackable
    delegate._setattr_tracking = value

  @property
  def _update_uid(self):
    """The wrapped trackable's `_update_uid` value."""
    delegate = self._trackable
    return delegate._update_uid

  @_update_uid.setter
  def _update_uid(self, value):
    delegate = self._trackable
    delegate._update_uid = value

  # ---------------------------------------------------------------------------
  # Read-only attributes: no setter is defined, so they can only be read
  # through this mixin.
  # ---------------------------------------------------------------------------

  @property
  def _unconditional_checkpoint_dependencies(self):
    """The wrapped trackable's `_unconditional_checkpoint_dependencies`."""
    delegate = self._trackable
    return delegate._unconditional_checkpoint_dependencies

  @property
  def _unconditional_dependency_names(self):
    """The wrapped trackable's `_unconditional_dependency_names`."""
    delegate = self._trackable
    return delegate._unconditional_dependency_names

  @property
  def _name_based_restores(self):
    """The wrapped trackable's `_name_based_restores`."""
    delegate = self._trackable
    return delegate._name_based_restores

  @property
  def _object_identifier(self):
    """The wrapped trackable's `_object_identifier`."""
    delegate = self._trackable
    return delegate._object_identifier

  @property
  def _tracking_metadata(self):
    """The wrapped trackable's `_tracking_metadata`."""
    delegate = self._trackable
    return delegate._tracking_metadata

  @property
  def _checkpoint_dependencies(self):
    """The wrapped trackable's `_checkpoint_dependencies`."""
    delegate = self._trackable
    return delegate._checkpoint_dependencies

  @property
  def _deferred_dependencies(self):
    """The wrapped trackable's `_deferred_dependencies`."""
    delegate = self._trackable
    return delegate._deferred_dependencies

  # ---------------------------------------------------------------------------
  # Methods: arguments are passed through untouched and the wrapped
  # trackable's return value is handed back to the caller.
  # ---------------------------------------------------------------------------

  def _maybe_initialize_trackable(self):
    """Calls `_maybe_initialize_trackable` on the wrapped trackable."""
    delegate = self._trackable
    return delegate._maybe_initialize_trackable()

  def _no_dependency(self, *args, **kwargs):
    """Calls `_no_dependency` on the wrapped trackable."""
    delegate = self._trackable
    return delegate._no_dependency(*args, **kwargs)

  def _name_based_attribute_restore(self, *args, **kwargs):
    """Calls `_name_based_attribute_restore` on the wrapped trackable."""
    delegate = self._trackable
    return delegate._name_based_attribute_restore(*args, **kwargs)

  def _lookup_dependency(self, *args, **kwargs):
    """Calls `_lookup_dependency` on the wrapped trackable."""
    delegate = self._trackable
    return delegate._lookup_dependency(*args, **kwargs)

  def _add_variable_with_custom_getter(self, *args, **kwargs):
    """Calls `_add_variable_with_custom_getter` on the wrapped trackable."""
    delegate = self._trackable
    return delegate._add_variable_with_custom_getter(*args, **kwargs)

  def _preload_simple_restoration(self, *args, **kwargs):
    """Calls `_preload_simple_restoration` on the wrapped trackable."""
    delegate = self._trackable
    return delegate._preload_simple_restoration(*args, **kwargs)

  def _track_trackable(self, *args, **kwargs):  # pylint: disable=redefined-outer-name
    """Calls `_track_trackable` on the wrapped trackable."""
    delegate = self._trackable
    return delegate._track_trackable(*args, **kwargs)

  def _handle_deferred_dependencies(self, name, trackable):  # pylint: disable=redefined-outer-name
    """Calls `_handle_deferred_dependencies` on the wrapped trackable.

    Args:
      name: Passed through unchanged as the first argument.
      trackable: Passed through unchanged as the second argument.

    Returns:
      Whatever the wrapped trackable's method returns.
    """
    delegate = self._trackable
    return delegate._handle_deferred_dependencies(name, trackable)

  def _gather_saveables_for_checkpoint(self, *args, **kwargs):
    """Calls `_gather_saveables_for_checkpoint` on the wrapped trackable."""
    delegate = self._trackable
    return delegate._gather_saveables_for_checkpoint(*args, **kwargs)

  def _trackable_children(self, *args, **kwargs):
    """Calls `_trackable_children` on the wrapped trackable."""
    delegate = self._trackable
    return delegate._trackable_children(*args, **kwargs)

  def _deserialization_dependencies(self, *args, **kwargs):
    """Calls `_deserialization_dependencies` on the wrapped trackable."""
    delegate = self._trackable
    return delegate._deserialization_dependencies(*args, **kwargs)

  def _export_to_saved_model_graph(self, *args, **kwargs):
    """Calls `_export_to_saved_model_graph` on the wrapped trackable."""
    delegate = self._trackable
    return delegate._export_to_saved_model_graph(*args, **kwargs)
  # pylint: enable=protected-access
135
136