# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions implementing Layer SavedModel serialization."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.saving.saved_model import base_serialization
from tensorflow.python.keras.saving.saved_model import constants
from tensorflow.python.keras.saving.saved_model import save_impl
from tensorflow.python.keras.saving.saved_model import serialized_attributes
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import nest


class LayerSavedModelSaver(base_serialization.SavedModelSaver):
  """Implements Layer SavedModel serialization."""

  @property
  def object_identifier(self):
    """Returns `constants.LAYER_IDENTIFIER`."""
    return constants.LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Returns the dictionary built by `_python_properties_internal()`."""
    # TODO(kathywu): Add python property validator
    return self._python_properties_internal()

  def _python_properties_internal(self):
    """Returns dictionary of all python properties."""
    # TODO(kathywu): Add support for metrics serialization.
    # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
    # the python config serialization has caught up.
    layer = self.obj
    # pylint: disable=protected-access
    metadata = {
        'name': layer.name,
        'trainable': layer.trainable,
        'expects_training_arg': layer._expects_training_arg,
        'dtype': policy.serialize(layer._dtype_policy),
        'batch_input_shape': getattr(layer, '_batch_input_shape', None),
        'stateful': layer.stateful,
        'must_restore_from_config': layer._must_restore_from_config,
    }
    # pylint: enable=protected-access

    # Merge in whatever `get_serialized` produces for the wrapped object.
    metadata.update(get_serialized(layer))

    if layer.input_spec is not None:

      def _serialize_spec(spec):
        # Falsy entries in the structure are stored as None.
        if not spec:
          return None
        return generic_utils.serialize_keras_object(spec)

      # Layer's input_spec has already been type-checked in the property
      # setter.
      metadata['input_spec'] = nest.map_structure(
          _serialize_spec, layer.input_spec)

    # Only regularizers exposing `get_config` are written to the metadata.
    if (layer.activity_regularizer is not None and
        hasattr(layer.activity_regularizer, 'get_config')):
      metadata['activity_regularizer'] = generic_utils.serialize_keras_object(
          layer.activity_regularizer)

    if layer._build_input_shape is not None:  # pylint: disable=protected-access
      metadata['build_input_shape'] = layer._build_input_shape  # pylint: disable=protected-access

    return metadata

  def objects_to_serialize(self, serialization_cache):
    """Returns `objects_to_serialize` of the cached serialized attributes."""
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.objects_to_serialize

  def functions_to_serialize(self, serialization_cache):
    """Returns `functions_to_serialize` of the cached serialized attributes."""
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.functions_to_serialize

  def _get_serialized_attributes(self, serialization_cache):
    """Generates or retrieves serialized attributes from cache."""
    layer = self.obj
    keras_cache = serialization_cache.setdefault(constants.KERAS_CACHE_KEY, {})
    if layer in keras_cache:
      return keras_cache[layer]

    # The cache entry is registered before the attributes are populated below.
    serialized_attr = serialized_attributes.SerializedAttributes.new(layer)
    keras_cache[layer] = serialized_attr

    skip_population = (
        save_impl.should_skip_serialization(layer) or
        layer._must_restore_from_config)  # pylint: disable=protected-access
    if not skip_population:
      object_dict, function_dict = self._get_serialized_attributes_internal(
          serialization_cache)
      serialized_attr.set_and_validate_objects(object_dict)
      serialized_attr.set_and_validate_functions(function_dict)

    return serialized_attr

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Returns dictionary of serialized attributes."""
    layer = self.obj
    wrapped_objects = save_impl.wrap_layer_objects(layer, serialization_cache)
    wrapped_functions = save_impl.wrap_layer_functions(
        layer, serialization_cache)
    # Attribute validator requires that the default save signature is added to
    # function dict, even if the value is None.
    wrapped_functions['_default_save_signature'] = None
    return wrapped_objects, wrapped_functions


# TODO(kathywu): Move serialization utils (and related utils from
# generic_utils.py) to a separate file.
def get_serialized(obj):
  """Serializes `obj` inside `generic_utils.skip_failed_serialization()`."""
  with generic_utils.skip_failed_serialization():
    # Store the config dictionary, which may be used when reviving the object.
    # When loading, the program will attempt to revive the object from config,
    # and if that fails, the object will be revived from the SavedModel.
    serialized = generic_utils.serialize_keras_object(obj)
    return serialized


class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
  """InputLayer serialization."""

  @property
  def object_identifier(self):
    """Returns `constants.INPUT_LAYER_IDENTIFIER`."""
    return constants.INPUT_LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Returns a dictionary of properties read from the wrapped object."""
    layer = self.obj
    return {
        'class_name': type(layer).__name__,
        'name': layer.name,
        'dtype': layer.dtype,
        'sparse': layer.sparse,
        'ragged': layer.ragged,
        'batch_input_shape': layer._batch_input_shape,  # pylint: disable=protected-access
        'config': layer.get_config(),
    }

  def objects_to_serialize(self, serialization_cache):
    """Returns an empty dictionary."""
    return {}

  def functions_to_serialize(self, serialization_cache):
    """Returns an empty dictionary."""
    return {}


class RNNSavedModelSaver(LayerSavedModelSaver):
  """RNN layer serialization."""

  @property
  def object_identifier(self):
    """Returns `constants.RNN_LAYER_IDENTIFIER`."""
    return constants.RNN_LAYER_IDENTIFIER

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Returns the parent's objects and functions, plus a `states` object."""
    parent = super(RNNSavedModelSaver, self)
    objects, functions = parent._get_serialized_attributes_internal(
        serialization_cache)

    rnn_states = data_structures.wrap_or_unwrap(self.obj.states)
    # Force the tuple into TupleWrapper which is a trackable object. The
    # save/load code requires all the objects to be trackable.
    # A tuple is not converted to TupleWrapper by
    # data_structures.wrap_or_unwrap() if it does not contain any trackable
    # objects.
    if isinstance(rnn_states, tuple):
      rnn_states = data_structures._TupleWrapper(rnn_states)  # pylint: disable=protected-access

    objects['states'] = rnn_states
    return objects, functions