1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Utility functions to save/load keras Model to/from SavedModel.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import six 23 24from tensorflow.python.client import session 25from tensorflow.python.framework import ops 26from tensorflow.python.keras import backend as K 27from tensorflow.python.keras import optimizers 28from tensorflow.python.keras.optimizer_v2 import optimizer_v2 29from tensorflow.python.keras.saving import model_from_json 30from tensorflow.python.keras.saving import saving_utils 31from tensorflow.python.keras.utils import mode_keys 32from tensorflow.python.lib.io import file_io 33from tensorflow.python.ops import variables 34from tensorflow.python.platform import tf_logging as logging 35from tensorflow.python.saved_model import builder as saved_model_builder 36from tensorflow.python.saved_model import constants 37from tensorflow.python.saved_model import model_utils 38from tensorflow.python.saved_model import save as save_lib 39from tensorflow.python.saved_model import utils_impl as saved_model_utils 40from tensorflow.python.training import saver as saver_lib 41from tensorflow.python.training.tracking import graph_view 42from tensorflow.python.util import compat 43from tensorflow.python.util import nest 44from tensorflow.python.util.tf_export import keras_export 45 46 47@keras_export('keras.experimental.export_saved_model') 48def export_saved_model(model, 49 saved_model_path, 50 custom_objects=None, 51 as_text=False, 52 input_signature=None, 53 serving_only=False): 54 """Exports a `tf.keras.Model` as a Tensorflow SavedModel. 55 56 Note that at this time, subclassed models can only be saved using 57 `serving_only=True`. 58 59 The exported `SavedModel` is a standalone serialization of Tensorflow objects, 60 and is supported by TF language APIs and the Tensorflow Serving system. 61 To load the model, use the function 62 `tf.keras.experimental.load_from_saved_model`. 63 64 The `SavedModel` contains: 65 66 1. a checkpoint containing the model weights. 67 2. a `SavedModel` proto containing the Tensorflow backend graph. Separate 68 graphs are saved for prediction (serving), train, and evaluation. If 69 the model has not been compiled, then only the graph computing predictions 70 will be exported. 71 3. the model's json config. If the model is subclassed, this will only be 72 included if the model's `get_config()` method is overwritten. 73 74 Example: 75 76 ```python 77 import tensorflow as tf 78 79 # Create a tf.keras model. 80 model = tf.keras.Sequential() 81 model.add(tf.keras.layers.Dense(1, input_shape=[10])) 82 model.summary() 83 84 # Save the tf.keras model in the SavedModel format. 85 path = '/tmp/simple_keras_model' 86 tf.keras.experimental.export_saved_model(model, path) 87 88 # Load the saved keras model back. 89 new_model = tf.keras.experimental.load_from_saved_model(path) 90 new_model.summary() 91 ``` 92 93 Args: 94 model: A `tf.keras.Model` to be saved. If the model is subclassed, the flag 95 `serving_only` must be set to True. 96 saved_model_path: a string specifying the path to the SavedModel directory. 97 custom_objects: Optional dictionary mapping string names to custom classes 98 or functions (e.g. custom loss functions). 99 as_text: bool, `False` by default. Whether to write the `SavedModel` proto 100 in text format. Currently unavailable in serving-only mode. 101 input_signature: A possibly nested sequence of `tf.TensorSpec` objects, used 102 to specify the expected model inputs. See `tf.function` for more details. 103 serving_only: bool, `False` by default. When this is true, only the 104 prediction graph is saved. 105 106 Raises: 107 NotImplementedError: If the model is a subclassed model, and serving_only is 108 False. 109 ValueError: If the input signature cannot be inferred from the model. 110 AssertionError: If the SavedModel directory already exists and isn't empty. 111 """ 112 if serving_only: 113 save_lib.save( 114 model, 115 saved_model_path, 116 signatures=saving_utils.trace_model_call(model, input_signature)) 117 else: 118 _save_v1_format(model, saved_model_path, custom_objects, as_text, 119 input_signature) 120 121 try: 122 _export_model_json(model, saved_model_path) 123 except NotImplementedError: 124 logging.warning('Skipped saving model JSON, subclassed model does not have ' 125 'get_config() defined.') 126 127 128def _export_model_json(model, saved_model_path): 129 """Saves model configuration as a json string under assets folder.""" 130 model_json = model.to_json() 131 model_json_filepath = os.path.join( 132 saved_model_utils.get_or_create_assets_dir(saved_model_path), 133 compat.as_text(constants.SAVED_MODEL_FILENAME_JSON)) 134 file_io.write_string_to_file(model_json_filepath, model_json) 135 136 137def _export_model_variables(model, saved_model_path): 138 """Saves model weights in checkpoint format under variables folder.""" 139 saved_model_utils.get_or_create_variables_dir(saved_model_path) 140 checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path) 141 model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True) 142 return checkpoint_prefix 143 144 145def _save_v1_format(model, path, custom_objects, as_text, input_signature): 146 """Exports model to v1 SavedModel format.""" 147 from tensorflow.python.keras.engine import sequential # pylint: disable=g-import-not-at-top 148 149 if not model._is_graph_network: 150 if isinstance(model, sequential.Sequential): 151 # If input shape is not directly set in the model, the exported model 152 # will infer the expected shapes of the input from the model. 153 if not model.built and input_signature is None: 154 raise ValueError( 155 'Sequential model\'s input shape is unknown. Please build the ' 156 'model, or use the input_signature argument to specify the ' 157 'model inputs.') 158 else: 159 raise NotImplementedError( 160 'Subclassed models can only be exported for serving. Please set ' 161 'argument serving_only=True.') 162 163 builder = saved_model_builder._SavedModelBuilder(path) 164 165 # Manually save variables to export them in an object-based checkpoint. This 166 # skips the `builder.add_meta_graph_and_variables()` step, which saves a 167 # named-based checkpoint. 168 # TODO(b/113134168): Add fn to Builder to save with object-based saver. 169 # TODO(b/113178242): This should only export the model json structure. Only 170 # one save is needed once the weights can be copied from the model to clone. 171 checkpoint_path = _export_model_variables(model, path) 172 173 # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that 174 # Keras models and `Estimator`s are exported with the same format. 175 # Every time a mode is exported, the code checks to see if new variables have 176 # been created (e.g. optimizer slot variables). If that is the case, the 177 # checkpoint is re-saved to include the new variables. 178 export_args = {'builder': builder, 179 'model': model, 180 'custom_objects': custom_objects, 181 'checkpoint_path': checkpoint_path, 182 'input_signature': input_signature} 183 184 has_saved_vars = False 185 if model.optimizer: 186 if isinstance(model.optimizer, (optimizers.TFOptimizer, 187 optimizer_v2.OptimizerV2)): 188 _export_mode(mode_keys.ModeKeys.TRAIN, has_saved_vars, **export_args) 189 has_saved_vars = True 190 _export_mode(mode_keys.ModeKeys.TEST, has_saved_vars, **export_args) 191 else: 192 logging.warning( 193 'Model was compiled with an optimizer, but the optimizer is not from ' 194 '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving ' 195 'graph was exported. The train and evaluate graphs were not added to ' 196 'the SavedModel.') 197 _export_mode(mode_keys.ModeKeys.PREDICT, has_saved_vars, **export_args) 198 199 builder.save(as_text) 200 201 202def _get_var_list(model): 203 """Returns list of all checkpointed saveable objects in the model.""" 204 var_list, _, _ = graph_view.ObjectGraphView(model).serialize_object_graph() 205 return var_list 206 207 208def create_placeholder(spec): 209 return K.placeholder(shape=spec.shape, dtype=spec.dtype, name=spec.name) 210 211 212def _export_mode( 213 mode, has_saved_vars, builder, model, custom_objects, checkpoint_path, 214 input_signature): 215 """Exports a model, and optionally saves new vars from the clone model. 216 217 Args: 218 mode: A `tf.estimator.ModeKeys` string. 219 has_saved_vars: A `boolean` indicating whether the SavedModel has already 220 exported variables. 221 builder: A `SavedModelBuilder` object. 222 model: A `tf.keras.Model` object. 223 custom_objects: A dictionary mapping string names to custom classes 224 or functions. 225 checkpoint_path: String path to checkpoint. 226 input_signature: Nested TensorSpec containing the expected inputs. Can be 227 `None`, in which case the signature will be inferred from the model. 228 229 Raises: 230 ValueError: If the train/eval mode is being exported, but the model does 231 not have an optimizer. 232 """ 233 from tensorflow.python.keras import models as models_lib # pylint: disable=g-import-not-at-top 234 compile_clone = (mode != mode_keys.ModeKeys.PREDICT) 235 if compile_clone and not model.optimizer: 236 raise ValueError( 237 'Model does not have an optimizer. Cannot export mode %s' % mode) 238 239 model_graph = ops.get_default_graph() 240 with ops.Graph().as_default() as g, K.learning_phase_scope( 241 mode == mode_keys.ModeKeys.TRAIN): 242 243 if input_signature is None: 244 input_tensors = None 245 else: 246 input_tensors = nest.map_structure(create_placeholder, input_signature) 247 248 # Clone the model into blank graph. This will create placeholders for inputs 249 # and targets. 250 clone = models_lib.clone_and_build_model( 251 model, input_tensors=input_tensors, custom_objects=custom_objects, 252 compile_clone=compile_clone) 253 254 # Make sure that iterations variable is added to the global step collection, 255 # to ensure that, when the SavedModel graph is loaded, the iterations 256 # variable is returned by `tf.train.get_global_step()`. This is required for 257 # compatibility with the SavedModelEstimator. 258 if compile_clone: 259 g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations) 260 261 # Extract update and train ops from train/test/predict functions. 262 train_op = None 263 if mode == mode_keys.ModeKeys.TRAIN: 264 clone._make_train_function() 265 train_op = clone.train_function.updates_op 266 elif mode == mode_keys.ModeKeys.TEST: 267 clone._make_test_function() 268 else: 269 clone._make_predict_function() 270 g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates) 271 272 with session.Session().as_default(): 273 clone_var_list = _get_var_list(clone) 274 if has_saved_vars: 275 # Confirm all variables in the clone have an entry in the checkpoint. 276 status = clone.load_weights(checkpoint_path) 277 status.assert_existing_objects_matched() 278 else: 279 # Confirm that variables between the clone and model match up exactly, 280 # not counting optimizer objects. Optimizer objects are ignored because 281 # if the model has not trained, the slot variables will not have been 282 # created yet. 283 # TODO(b/113179535): Replace with trackable equivalence. 284 _assert_same_non_optimizer_objects(model, model_graph, clone, g) 285 286 # TODO(b/113178242): Use value transfer for trackable objects. 287 clone.load_weights(checkpoint_path) 288 289 # Add graph and variables to SavedModel. 290 # TODO(b/113134168): Switch to add_meta_graph_and_variables. 291 clone.save_weights(checkpoint_path, save_format='tf', overwrite=True) 292 builder._has_saved_variables = True 293 294 # Add graph to the SavedModel builder. 295 builder.add_meta_graph( 296 model_utils.EXPORT_TAG_MAP[mode], 297 signature_def_map=_create_signature_def_map(clone, mode), 298 saver=saver_lib.Saver(clone_var_list), 299 init_op=variables.local_variables_initializer(), 300 train_op=train_op) 301 return None 302 303 304def _create_signature_def_map(model, mode): 305 """Creates a SignatureDef map from a Keras model.""" 306 inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)} 307 if model.optimizer: 308 targets_dict = {x.name.split(':')[0]: x 309 for x in model.targets if x is not None} 310 inputs_dict.update(targets_dict) 311 outputs_dict = {name: x 312 for name, x in zip(model.output_names, model.outputs)} 313 metrics = saving_utils.extract_model_metrics(model) 314 315 # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables 316 # are by default not added to any collections. We are doing this here, so 317 # that metric variables get initialized. 318 local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) 319 vars_to_add = set() 320 if metrics is not None: 321 from tensorflow.python.keras.metrics import Metric # pylint: disable=g-import-not-at-top 322 for key, value in six.iteritems(metrics): 323 if isinstance(value, Metric): 324 vars_to_add.update(value.variables) 325 # Convert Metric instances to (value_tensor, update_op) tuple. 326 metrics[key] = (value.result(), value.updates[0]) 327 # Remove variables that are in the local variables collection already. 328 vars_to_add = vars_to_add.difference(local_vars) 329 for v in vars_to_add: 330 ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v) 331 332 export_outputs = model_utils.export_outputs_for_mode( 333 mode, 334 predictions=outputs_dict, 335 loss=model.total_loss if model.optimizer else None, 336 metrics=metrics) 337 return model_utils.build_all_signature_defs( 338 inputs_dict, 339 export_outputs=export_outputs, 340 serving_only=(mode == mode_keys.ModeKeys.PREDICT)) 341 342 343def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): # pylint: disable=unused-argument 344 """Asserts model and clone contain the same trackable objects.""" 345 346 # TODO(fchollet, kathywu): make sure this works in eager mode. 347 return True 348 349 350@keras_export('keras.experimental.load_from_saved_model') 351def load_from_saved_model(saved_model_path, custom_objects=None): 352 """Loads a keras Model from a SavedModel created by `export_saved_model()`. 353 354 This function reinstantiates model state by: 355 1) loading model topology from json (this will eventually come 356 from metagraph). 357 2) loading model weights from checkpoint. 358 359 Example: 360 361 ```python 362 import tensorflow as tf 363 364 # Create a tf.keras model. 365 model = tf.keras.Sequential() 366 model.add(tf.keras.layers.Dense(1, input_shape=[10])) 367 model.summary() 368 369 # Save the tf.keras model in the SavedModel format. 370 path = '/tmp/simple_keras_model' 371 tf.keras.experimental.export_saved_model(model, path) 372 373 # Load the saved keras model back. 374 new_model = tf.keras.experimental.load_from_saved_model(path) 375 new_model.summary() 376 ``` 377 378 Args: 379 saved_model_path: a string specifying the path to an existing SavedModel. 380 custom_objects: Optional dictionary mapping names 381 (strings) to custom classes or functions to be 382 considered during deserialization. 383 384 Returns: 385 a keras.Model instance. 386 """ 387 # restore model topology from json string 388 model_json_filepath = os.path.join( 389 compat.as_bytes(saved_model_path), 390 compat.as_bytes(constants.ASSETS_DIRECTORY), 391 compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON)) 392 model_json = file_io.read_file_to_string(model_json_filepath) 393 model = model_from_json(model_json, custom_objects=custom_objects) 394 395 # restore model weights 396 checkpoint_prefix = os.path.join( 397 compat.as_text(saved_model_path), 398 compat.as_text(constants.VARIABLES_DIRECTORY), 399 compat.as_text(constants.VARIABLES_FILENAME)) 400 model.load_weights(checkpoint_prefix) 401 return model 402