# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Deprecated experimental Keras SavedModel implementation."""

import os
import warnings

from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras import optimizer_v1
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving import model_config
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving import utils_v1 as model_utils
from tensorflow.python.keras.utils import mode_keys
from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import save as save_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
metrics_lib = LazyLoader("metrics_lib", globals(),
                         "tensorflow.python.keras.metrics")
models_lib = LazyLoader("models_lib", globals(),
                        "tensorflow.python.keras.models")
sequential = LazyLoader(
    "sequential", globals(),
    "tensorflow.python.keras.engine.sequential")
# pylint:enable=g-inconsistent-quotes


# File name for json format of SavedModel.
SAVED_MODEL_FILENAME_JSON = 'saved_model.json'


@keras_export(v1=['keras.experimental.export_saved_model'])
def export_saved_model(model,
                       saved_model_path,
                       custom_objects=None,
                       as_text=False,
                       input_signature=None,
                       serving_only=False):
  """Exports a `tf.keras.Model` as a TensorFlow SavedModel.

  Note that at this time, subclassed models can only be saved using
  `serving_only=True`.

  The exported `SavedModel` is a standalone serialization of TensorFlow objects,
  and is supported by TF language APIs and the TensorFlow Serving system.
  To load the model, use the function
  `tf.keras.experimental.load_from_saved_model`.

  The `SavedModel` contains:

  1. a checkpoint containing the model weights.
  2. a `SavedModel` proto containing the TensorFlow backend graph. Separate
     graphs are saved for prediction (serving), train, and evaluation. If
     the model has not been compiled, or was compiled with an optimizer that
     is neither a `TFOptimizer` nor an `OptimizerV2`, then only the graph
     computing predictions will be exported.
  3. the model's json config. If the model is subclassed, this will only be
     included if the model's `get_config()` method is overridden.

  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  path = '/tmp/simple_keras_model'
  tf.keras.experimental.export_saved_model(model, path)

  # Load the saved keras model back.
  new_model = tf.keras.experimental.load_from_saved_model(path)
  new_model.summary()
  ```

  Args:
    model: A `tf.keras.Model` to be saved. If the model is subclassed, the flag
      `serving_only` must be set to True.
    saved_model_path: a string specifying the path to the SavedModel directory.
    custom_objects: Optional dictionary mapping string names to custom classes
      or functions (e.g. custom loss functions).
    as_text: bool, `False` by default. Whether to write the `SavedModel` proto
      in text format. Currently unavailable in serving-only mode.
    input_signature: A possibly nested sequence of `tf.TensorSpec` objects, used
      to specify the expected model inputs. See `tf.function` for more details.
    serving_only: bool, `False` by default. When this is true, only the
      prediction graph is saved.

  Raises:
    NotImplementedError: If the model is a subclassed model, and serving_only is
      False.
    ValueError: If the input signature cannot be inferred from the model.
    AssertionError: If the SavedModel directory already exists and isn't empty.
  """
  # stacklevel=2 attributes the warning to the caller, not to this module.
  warnings.warn('`tf.keras.experimental.export_saved_model` is deprecated '
                'and will be removed in a future version. '
                'Please use `model.save(..., save_format="tf")` or '
                '`tf.keras.models.save_model(..., save_format="tf")`.',
                stacklevel=2)
  if serving_only:
    save_lib.save(
        model,
        saved_model_path,
        signatures=saving_utils.trace_model_call(model, input_signature))
  else:
    _save_v1_format(model, saved_model_path, custom_objects, as_text,
                    input_signature)

  try:
    _export_model_json(model, saved_model_path)
  except NotImplementedError:
    # Best effort: a subclassed model without `get_config()` cannot produce a
    # JSON config, but the rest of the export is still usable.
    logging.warning('Skipped saving model JSON, subclassed model does not have '
                    'get_config() defined.')


def _export_model_json(model, saved_model_path):
  """Saves model configuration as a json string under assets folder."""
  model_json = model.to_json()
  model_json_filepath = os.path.join(
      _get_or_create_assets_dir(saved_model_path),
      compat.as_text(SAVED_MODEL_FILENAME_JSON))
  with gfile.Open(model_json_filepath, 'w') as f:
    f.write(model_json)


def _export_model_variables(model, saved_model_path):
  """Saves model weights in checkpoint format under variables folder.

  Returns:
    The checkpoint prefix the weights were written to.
  """
  _get_or_create_variables_dir(saved_model_path)
  checkpoint_prefix = _get_variables_path(saved_model_path)
  model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
  return checkpoint_prefix


def _save_v1_format(model, path, custom_objects, as_text, input_signature):
  """Exports model to v1 SavedModel format.

  Writes the weights checkpoint, then one meta graph per supported mode
  (train/eval only when the optimizer is a `TFOptimizer` or `OptimizerV2`;
  predict always), and finally writes the SavedModel proto.

  Raises:
    ValueError: If a non-graph-network Sequential model has not been built.
    NotImplementedError: If the model is subclassed.
  """
  if not model._is_graph_network:  # pylint: disable=protected-access
    if isinstance(model, sequential.Sequential):
      # If input shape is not directly set in the model, the exported model
      # will infer the expected shapes of the input from the model.
      if not model.built:
        raise ValueError('Weights for sequential model have not yet been '
                         'created. Weights are created when the Model is first '
                         'called on inputs or `build()` is called with an '
                         '`input_shape`, or the first layer in the model has '
                         '`input_shape` during construction.')
      # TODO(kathywu): Build the model with input_signature to create the
      # weights before _export_model_variables().
    else:
      raise NotImplementedError(
          'Subclassed models can only be exported for serving. Please set '
          'argument serving_only=True.')

  builder = saved_model_builder._SavedModelBuilder(path)  # pylint: disable=protected-access

  # Manually save variables to export them in an object-based checkpoint. This
  # skips the `builder.add_meta_graph_and_variables()` step, which saves a
  # named-based checkpoint.
  # TODO(b/113134168): Add fn to Builder to save with object-based saver.
  # TODO(b/113178242): This should only export the model json structure. Only
  # one save is needed once the weights can be copied from the model to clone.
  checkpoint_path = _export_model_variables(model, path)

  # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that
  # Keras models and `Estimator`s are exported with the same format.
  # Every time a mode is exported, the code checks to see if new variables have
  # been created (e.g. optimizer slot variables). If that is the case, the
  # checkpoint is re-saved to include the new variables.
  export_args = {'builder': builder,
                 'model': model,
                 'custom_objects': custom_objects,
                 'checkpoint_path': checkpoint_path,
                 'input_signature': input_signature}

  has_saved_vars = False
  if model.optimizer:
    if isinstance(model.optimizer, (optimizer_v1.TFOptimizer,
                                    optimizer_v2.OptimizerV2)):
      _export_mode(mode_keys.ModeKeys.TRAIN, has_saved_vars, **export_args)
      has_saved_vars = True
      _export_mode(mode_keys.ModeKeys.TEST, has_saved_vars, **export_args)
    else:
      logging.warning(
          'Model was compiled with an optimizer, but the optimizer is not from '
          '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving '
          'graph was exported. The train and evaluate graphs were not added to '
          'the SavedModel.')
  _export_mode(mode_keys.ModeKeys.PREDICT, has_saved_vars, **export_args)

  builder.save(as_text)


def _get_var_list(model):
  """Returns list of all checkpointed saveable objects in the model."""
  var_list, _, _ = graph_view.ObjectGraphView(model).serialize_object_graph()
  return var_list


def create_placeholder(spec):
  """Creates a backend placeholder matching `spec`'s shape, dtype and name."""
  return backend.placeholder(shape=spec.shape, dtype=spec.dtype, name=spec.name)


def _export_mode(
    mode, has_saved_vars, builder, model, custom_objects, checkpoint_path,
    input_signature):
  """Exports a model, and optionally saves new vars from the clone model.

  Args:
    mode: A `tf.estimator.ModeKeys` string.
    has_saved_vars: A `boolean` indicating whether the SavedModel has already
      exported variables.
    builder: A `SavedModelBuilder` object.
    model: A `tf.keras.Model` object.
    custom_objects: A dictionary mapping string names to custom classes
      or functions.
    checkpoint_path: String path to checkpoint.
    input_signature: Nested TensorSpec containing the expected inputs. Can be
      `None`, in which case the signature will be inferred from the model.

  Raises:
    ValueError: If the train/eval mode is being exported, but the model does
      not have an optimizer.
  """
  compile_clone = (mode != mode_keys.ModeKeys.PREDICT)
  if compile_clone and not model.optimizer:
    raise ValueError(
        'Model does not have an optimizer. Cannot export mode %s' % mode)

  # Captured before entering the fresh graph below, so the original model's
  # graph can be compared against the clone's.
  model_graph = ops.get_default_graph()
  with ops.Graph().as_default() as g, backend.learning_phase_scope(
      mode == mode_keys.ModeKeys.TRAIN):

    if input_signature is None:
      input_tensors = None
    else:
      input_tensors = nest.map_structure(create_placeholder, input_signature)

    # Clone the model into blank graph. This will create placeholders for inputs
    # and targets.
    clone = models_lib.clone_and_build_model(
        model, input_tensors=input_tensors, custom_objects=custom_objects,
        compile_clone=compile_clone)

    # Make sure that iterations variable is added to the global step collection,
    # to ensure that, when the SavedModel graph is loaded, the iterations
    # variable is returned by `tf.compat.v1.train.get_global_step()`. This is
    # required for compatibility with the SavedModelEstimator.
    if compile_clone:
      g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)

    # Extract update and train ops from train/test/predict functions.
    train_op = None
    if mode == mode_keys.ModeKeys.TRAIN:
      clone._make_train_function()  # pylint: disable=protected-access
      train_op = clone.train_function.updates_op
    elif mode == mode_keys.ModeKeys.TEST:
      clone._make_test_function()  # pylint: disable=protected-access
    else:
      clone._make_predict_function()  # pylint: disable=protected-access
    g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)

    # NOTE(review): this Session is never explicitly closed; presumably it is
    # released on garbage collection — confirm.
    with session.Session().as_default():
      clone_var_list = _get_var_list(clone)
      if has_saved_vars:
        # Confirm all variables in the clone have an entry in the checkpoint.
        status = clone.load_weights(checkpoint_path)
        status.assert_existing_objects_matched()
      else:
        # Confirm that variables between the clone and model match up exactly,
        # not counting optimizer objects. Optimizer objects are ignored because
        # if the model has not trained, the slot variables will not have been
        # created yet.
        # TODO(b/113179535): Replace with trackable equivalence.
        _assert_same_non_optimizer_objects(model, model_graph, clone, g)

        # TODO(b/113178242): Use value transfer for trackable objects.
        clone.load_weights(checkpoint_path)

        # Add graph and variables to SavedModel.
        # TODO(b/113134168): Switch to add_meta_graph_and_variables.
        clone.save_weights(checkpoint_path, save_format='tf', overwrite=True)
        builder._has_saved_variables = True  # pylint: disable=protected-access

      # Add graph to the SavedModel builder.
      builder.add_meta_graph(
          model_utils.EXPORT_TAG_MAP[mode],
          signature_def_map=_create_signature_def_map(clone, mode),
          saver=saver_lib.Saver(
              clone_var_list,
              # Allow saving Models with no variables. This is somewhat odd, but
              # it's not necessarily a bug.
              allow_empty=True),
          init_op=variables.local_variables_initializer(),
          train_op=train_op)
  return None


def _create_signature_def_map(model, mode):
  """Creates a SignatureDef map from a Keras model.

  Inputs are the model inputs plus (when an optimizer is set) the non-None
  target placeholders; outputs are the model outputs. `Metric` objects found
  by `saving_utils.extract_model_metrics` are converted in place to
  `(result_tensor, update_op)` tuples.
  """
  inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)}
  if model.optimizer:
    targets_dict = {x.name.split(':')[0]: x
                    for x in model._targets if x is not None}  # pylint: disable=protected-access
    inputs_dict.update(targets_dict)
  outputs_dict = {name: x
                  for name, x in zip(model.output_names, model.outputs)}
  metrics = saving_utils.extract_model_metrics(model)

  # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables
  # are by default not added to any collections. We are doing this here, so
  # that metric variables get initialized.
  local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
  vars_to_add = set()
  if metrics is not None:
    for key, value in metrics.items():
      if isinstance(value, metrics_lib.Metric):
        vars_to_add.update(value.variables)
        # Convert Metric instances to (value_tensor, update_op) tuple.
        # Only existing keys are reassigned, so mutating during iteration is
        # safe here.
        metrics[key] = (value.result(), value.updates[0])
  # Remove variables that are in the local variables collection already.
  vars_to_add = vars_to_add.difference(local_vars)
  for v in vars_to_add:
    ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v)

  export_outputs = model_utils.export_outputs_for_mode(
      mode,
      predictions=outputs_dict,
      loss=model.total_loss if model.optimizer else None,
      metrics=metrics)
  return model_utils.build_all_signature_defs(
      inputs_dict,
      export_outputs=export_outputs,
      serving_only=(mode == mode_keys.ModeKeys.PREDICT))


def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph):  # pylint: disable=unused-argument
  """Asserts model and clone contain the same trackable objects.

  Currently a no-op placeholder that always returns True.
  """

  # TODO(fchollet, kathywu): make sure this works in eager mode.
  return True


@keras_export(v1=['keras.experimental.load_from_saved_model'])
def load_from_saved_model(saved_model_path, custom_objects=None):
  """Loads a keras Model from a SavedModel created by `export_saved_model()`.

  This function reinstantiates model state by:
  1) loading model topology from json (this will eventually come
     from metagraph).
  2) loading model weights from checkpoint.

  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  path = '/tmp/simple_keras_model'
  tf.keras.experimental.export_saved_model(model, path)

  # Load the saved keras model back.
  new_model = tf.keras.experimental.load_from_saved_model(path)
  new_model.summary()
  ```

  Args:
    saved_model_path: a string specifying the path to an existing SavedModel.
    custom_objects: Optional dictionary mapping names
        (strings) to custom classes or functions to be
        considered during deserialization.

  Returns:
    a keras.Model instance.
  """
  # stacklevel=2 attributes the warning to the caller, not to this module.
  warnings.warn('`tf.keras.experimental.load_from_saved_model` is deprecated '
                'and will be removed in a future version. '
                'Please switch to `tf.keras.models.load_model`.',
                stacklevel=2)
  # Restore model topology from the json string. The path is built with the
  # same helper `_export_model_json` uses, so read and write locations agree.
  model_json_filepath = os.path.join(
      _get_assets_dir(saved_model_path),
      compat.as_text(SAVED_MODEL_FILENAME_JSON))
  with gfile.Open(model_json_filepath, 'r') as f:
    model_json = f.read()
  model = model_config.model_from_json(
      model_json, custom_objects=custom_objects)

  # Restore model weights from the prefix `_export_model_variables` wrote to.
  checkpoint_prefix = _get_variables_path(saved_model_path)
  model.load_weights(checkpoint_prefix)
  return model


#### Directory / path helpers


def _get_or_create_variables_dir(export_dir):
  """Return variables sub-directory, or create one if it doesn't exist."""
  variables_dir = _get_variables_dir(export_dir)
  file_io.recursive_create_dir(variables_dir)
  return variables_dir


def _get_variables_dir(export_dir):
  """Return variables sub-directory in the SavedModel."""
  return os.path.join(
      compat.as_text(export_dir),
      compat.as_text(constants.VARIABLES_DIRECTORY))


def _get_variables_path(export_dir):
  """Return the variables path, used as the prefix for checkpoint files."""
  return os.path.join(
      compat.as_text(_get_variables_dir(export_dir)),
      compat.as_text(constants.VARIABLES_FILENAME))


def _get_or_create_assets_dir(export_dir):
  """Return assets sub-directory, or create one if it doesn't exist."""
  assets_destination_dir = _get_assets_dir(export_dir)

  file_io.recursive_create_dir(assets_destination_dir)

  return assets_destination_dir


def _get_assets_dir(export_dir):
  """Return path to asset directory in the SavedModel."""
  return os.path.join(
      compat.as_text(export_dir),
      compat.as_text(constants.ASSETS_DIRECTORY))