1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Helper classes that list&validate all attributes to serialize to SavedModel. 16""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.eager import def_function 22from tensorflow.python.eager import function as defun 23from tensorflow.python.keras.saving.saved_model import constants 24from tensorflow.python.keras.utils.generic_utils import LazyLoader 25from tensorflow.python.training.tracking import base as trackable 26from tensorflow.python.training.tracking.tracking import AutoTrackable 27 28# TODO(b/134426265): Switch back to single-quotes to match the rest of the file 29# once the issue with copybara is fixed. 
# pylint:disable=g-inconsistent-quotes
base_layer = LazyLoader(
    "base_layer", globals(),
    "tensorflow.python.keras.engine.base_layer")
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
recurrent = LazyLoader(
    "recurrent", globals(),
    "tensorflow.python.keras.layers.recurrent")
# pylint:enable=g-inconsistent-quotes


class SerializedAttributes(object):
  """Class that tracks and validates all serialization attributes.

  Keras models contain many Python-defined components. For example, the
  trainable_variables property lists the model's trainable variables by
  recursively retrieving the trainable variables from each of the child layers.
  Another example is model.call, a python function that calls child layers and
  adds ops to the backend graph.

  Only TensorFlow checkpointable objects and functions can be serialized to
  SavedModel. Serializing a Keras model as-is results in a checkpointable object
  that does not resemble a Keras model at all. Thus, extra checkpointable
  objects and functions must be created during serialization.

  **Defining new serialized attributes**
  Child classes should be defined using:
    SerializedAttributes.with_attributes(
        'name', checkpointable_objects=[...], functions=[...], copy_from=[...])
  This class is used to cache generated checkpointable objects and functions,
  ensuring that new objects and functions are generated a single time.

  **Usage during serialization**
  Each Layer/Model object should have a corresponding instance of
  SerializedAttributes. Create a new instance by calling
  `SerializedAttributes.new(obj)`. Objects and functions may be saved using
  `.set_and_validate_objects`/`.set_and_validate_functions`.
  The properties `.checkpointable_objects` and `.functions` return the cached
  values.

  **Adding/changing attributes to save to SavedModel**
  1. Change the call to `SerializedAttributes.with_attributes` in the correct
     class:
     - CommonEndpoints: Base attributes to be added during serialization. If
       these attributes are present in a Trackable object, it can be
       deserialized to a Keras Model.
     - LayerAttributes: Attributes to serialize for Layer objects.
     - ModelAttributes: Attributes to serialize for Model objects.
  2. Update class docstring
  3. Update arguments to any calls to `set_and_validate_*`. For example, if
     `call_raw_tensors` is added to the ModelAttributes function list, then
     a `call_raw_tensors` function should be passed to
     `set_and_validate_functions`.

  **Common endpoints vs other attributes**
  Only common endpoints are attached directly to the root object. Keras-specific
  attributes are saved to a separate trackable object with the name "keras_api".
  The number of objects attached to the root is limited because any naming
  conflicts will cause user code to break.

  Another reason is that this will only affect users who call
  `tf.saved_model.load` instead of `tf.keras.models.load_model`. These are
  advanced users who are likely to have defined their own tf.functions and
  trackable objects. The added Keras-specific attributes are kept out of the way
  in the "keras_api" namespace.

  Properties defined in this class may be used to filter out keras-specific
  attributes:
  - `functions_to_serialize`: Returns dict of functions to attach to the root
    object.
  - `objects_to_serialize`: Returns dict of objects to attach to
    the root object (including separate trackable object containing
    keras-specific attributes)

  All changes to the serialized attributes must be backwards-compatible, so
  attributes should not be removed or modified without sufficient justification.
  """

  @staticmethod
  def with_attributes(
      name, checkpointable_objects=None, functions=None, copy_from=None):
    """Creates a subclass with all attributes as specified in the arguments.

    The argument lists are never modified; the attribute names are collected
    into new sets that are stored on the returned class.

    Args:
      name: Name of subclass
      checkpointable_objects: List of checkpointable objects to be serialized
        in the SavedModel.
      functions: List of functions to be serialized in the SavedModel.
      copy_from: List of other SerializedAttributes subclasses. The returned
        class will copy checkpoint objects/functions from each subclass.

    Returns:
      Child class with attributes as defined in the `checkpointable_objects`
      and `functions` lists. The names are exposed as the class-level sets
      `all_checkpointable_objects` and `all_functions`.
    """
    # Accumulate into fresh sets instead of extending the caller's lists in
    # place, so a list passed by the caller is left untouched.
    all_checkpointable_objects = set(checkpointable_objects or ())
    all_functions = set(functions or ())

    if copy_from is not None:
      for cls in copy_from:
        all_checkpointable_objects.update(cls.all_checkpointable_objects)
        all_functions.update(cls.all_functions)

    classdict = {
        'all_checkpointable_objects': all_checkpointable_objects,
        'all_functions': all_functions}
    return type(name, (SerializedAttributes,), classdict)

  @staticmethod
  def new(obj):
    """Returns a new SerializedAttribute object.

    Args:
      obj: Object to create the attribute container for.

    Returns:
      A `ModelAttributes`, `MetricAttributes`, `RNNAttributes` or
      `LayerAttributes` instance, selected by the type of `obj`.

    Raises:
      TypeError: If `obj` matches none of the supported Keras types.
    """
    # NOTE(review): the generic Layer check is deliberately last; presumably
    # Model, Metric and RNN are Layer subclasses, so the order matters --
    # confirm before reordering.
    if isinstance(obj, training_lib.Model):
      return ModelAttributes()
    elif isinstance(obj, metrics.Metric):
      return MetricAttributes()
    elif isinstance(obj, recurrent.RNN):
      return RNNAttributes()
    elif isinstance(obj, base_layer.Layer):
      return LayerAttributes()
    else:
      raise TypeError('Internal error during serialization: Expected Keras '
                      'Layer object, got {} of type {}'.format(obj, type(obj)))

  def __init__(self):
    # Validated checkpointable objects, keyed by attribute name.
    self._object_dict = {}
    # Validated functions, keyed by attribute name (values may be None).
    self._function_dict = {}
    # Every validated object/function is also set as an attribute on this
    # trackable, which is returned under `constants.KERAS_ATTR` by
    # `objects_to_serialize`.
    self._keras_trackable = AutoTrackable()

  @property
  def functions(self):
    """Returns dictionary of all functions (entries set to None are omitted)."""
    return {key: value for key, value in self._function_dict.items()
            if value is not None}

  @property
  def checkpointable_objects(self):
    """Returns dictionary of all checkpointable objects (None is omitted)."""
    return {key: value for key, value in self._object_dict.items()
            if value is not None}

  @property
  def functions_to_serialize(self):
    """Returns functions to attach to the root object during serialization."""
    return {key: value for key, value in self.functions.items()
            if key in CommonEndpoints.all_functions}

  @property
  def objects_to_serialize(self):
    """Returns objects to attach to the root object during serialization."""
    objects = {key: value for key, value in self.checkpointable_objects.items()
               if key in CommonEndpoints.all_checkpointable_objects}
    # The trackable holding all attributes set through `set_and_validate_*`
    # is attached alongside the common endpoints.
    objects[constants.KERAS_ATTR] = self._keras_trackable
    return objects

  def set_and_validate_functions(self, function_dict):
    """Saves function dictionary, and validates dictionary values.

    Args:
      function_dict: Dictionary mapping function names to functions. It must
        contain a key for every name in `all_functions`; a value may be None.
        Keys that are not in `all_functions` are ignored.

    Returns:
      Dictionary of all saved functions whose value is not None.

    Raises:
      ValueError: If a name in `all_functions` is missing from
        `function_dict`, or if a non-None value is not a `defun.Function` or
        `def_function.Function`.
    """
    for key in self.all_functions:
      if key not in function_dict:
        raise ValueError('Function {} missing from serialized function dict.'
                         .format(key))
      fn = function_dict[key]
      if (fn is not None and  # Not all functions are required
          not isinstance(fn, (defun.Function, def_function.Function))):
        raise ValueError(
            'Function dictionary contained a non-function object: {} (for key'
            ' {})'.format(fn, key))
      self._function_dict[key] = fn
      # None values are also set on the trackable, mirroring the cached dict.
      setattr(self._keras_trackable, key, fn)
    return self.functions

  def set_and_validate_objects(self, object_dict):
    """Saves objects to a dictionary, and validates the values.

    Args:
      object_dict: Dictionary mapping attribute names to trackable objects. It
        must contain a key for every name in `all_checkpointable_objects`.
        Keys that are not in `all_checkpointable_objects` are ignored.

    Returns:
      Dictionary of all saved checkpointable objects.

    Raises:
      ValueError: If a name in `all_checkpointable_objects` is missing from
        `object_dict`, or if a value is not a `trackable.Trackable`.
    """
    for key in self.all_checkpointable_objects:
      if key not in object_dict:
        raise ValueError(
            'Object {} missing from serialized object dict.'.format(key))
      obj = object_dict[key]
      if not isinstance(obj, trackable.Trackable):
        raise ValueError(
            'Object dictionary contained a non-trackable object: {} (for key'
            ' {})'.format(obj, key))
      self._object_dict[key] = obj
      setattr(self._keras_trackable, key, obj)
    return self.checkpointable_objects


class CommonEndpoints(SerializedAttributes.with_attributes(
    'CommonEndpoints',
    checkpointable_objects=['variables', 'trainable_variables',
                            'regularization_losses'],
    functions=['__call__', 'call_and_return_all_conditional_losses',
               '_default_save_signature'])):
  """Common endpoints shared by all models loadable by Keras.

  List of all attributes:
    variables: List of all variables in the model and its sublayers.
    trainable_variables: List of all trainable variables in the model and its
      sublayers.
    regularization_losses: List of all unconditional losses (losses not
      dependent on the inputs) in the model and its sublayers.
    __call__: Function that takes inputs and returns the outputs of the model
      call function.
    call_and_return_all_conditional_losses: Function that returns a tuple of
      (call function outputs, list of all losses that depend on the inputs).
    _default_save_signature: Traced model call function. This is only included
      if the top level exported object is a Keras model.
  """


class LayerAttributes(SerializedAttributes.with_attributes(
    'LayerAttributes',
    checkpointable_objects=['non_trainable_variables', 'layers', 'metrics',
                            'layer_regularization_losses', 'layer_metrics'],
    functions=['call_and_return_conditional_losses', 'activity_regularizer_fn'],
    copy_from=[CommonEndpoints]
    )):
  """Layer checkpointable objects + functions that are saved to the SavedModel.

  List of all attributes:
    All attributes from CommonEndpoints
    non_trainable_variables: List of non-trainable variables in the layer and
      its sublayers.
    layers: List of all sublayers.
    metrics: List of all metrics in the layer and its sublayers.
    call_and_return_conditional_losses: Function that takes inputs and returns a
      tuple of (outputs of the call function, list of input-dependent losses).
      The list of losses excludes the activity regularizer function, which is
      separate to allow the deserialized Layer object to define a different
      activity regularizer.
    activity_regularizer_fn: Callable that returns the activity regularizer loss
    layer_regularization_losses: List of losses owned only by this layer.
    layer_metrics: List of metrics owned by this layer.
  """


class ModelAttributes(SerializedAttributes.with_attributes(
    'ModelAttributes',
    copy_from=[LayerAttributes])):
  """Model checkpointable objects + functions that are saved to the SavedModel.

  List of all attributes:
    All attributes from LayerAttributes (including CommonEndpoints)
  """
  # TODO(kathywu): Add attributes `compile_losses` and `compile_metrics`, which
  #  list all losses and metrics defined by `model.compile`.


class MetricAttributes(
    SerializedAttributes.with_attributes(
        'MetricAttributes',
        checkpointable_objects=['variables'],
        functions=[],
    )):
  """Attributes that are added to Metric objects when saved to SavedModel.

  List of all attributes:
    variables: list of all variables
  """


class RNNAttributes(SerializedAttributes.with_attributes(
    'RNNAttributes',
    checkpointable_objects=['states'],
    copy_from=[LayerAttributes])):
  """RNN checkpointable objects + functions that are saved to the SavedModel.

  List of all attributes:
    All attributes from LayerAttributes (including CommonEndpoints)
    states: List of state variables
  """