1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Keras SavedModel serialization.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import os 21 22from tensorflow.core.framework import versions_pb2 23from tensorflow.python.distribute import distribution_strategy_context 24from tensorflow.python.keras import backend as K 25from tensorflow.python.keras.protobuf import saved_metadata_pb2 26from tensorflow.python.keras.saving import saving_utils 27from tensorflow.python.keras.saving.saved_model import constants 28from tensorflow.python.keras.saving.saved_model import save_impl 29from tensorflow.python.keras.saving.saved_model import utils 30from tensorflow.python.keras.utils.generic_utils import LazyLoader 31from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite 32from tensorflow.python.platform import gfile 33from tensorflow.python.saved_model import save as save_lib 34 35 36# To avoid circular dependencies between keras/engine and keras/saving, 37# code in keras/saving must delay imports. 
base_layer = LazyLoader(
    "base_layer", globals(),
    "tensorflow.python.keras.engine.base_layer")
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")


def save(model, filepath, overwrite, include_optimizer, signatures=None,
         options=None, save_traces=True):
  """Saves a model as a SavedModel to the filepath.

  Besides the SavedModel written by `tf.saved_model`, a serialized
  `SavedMetadata` proto describing every Keras layer in the saved object graph
  is written to `constants.SAVED_METADATA_PATH` inside `filepath`.

  When `include_optimizer` is False, `model.optimizer` is temporarily set to
  `None` for the duration of the save. It is restored afterwards, even if
  saving fails part-way.

  Args:
    model: Keras model instance to be saved.
    filepath: String path to save the model.
    overwrite: whether to overwrite the existing filepath.
    include_optimizer: If True, save the model's optimizer state.
    signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
      format only. Please see the `signatures` argument in `tf.saved_model.save`
      for details.
    options: (only applies to SavedModel format) `tf.saved_model.SaveOptions`
      object that specifies options for saving to SavedModel.
    save_traces: (only applies to SavedModel format) When enabled, the
      SavedModel will store the function traces for each layer. This
      can be disabled, so that only the configs of each layer are stored.
      Defaults to `True`. Disabling this will decrease serialization time
      and reduce file size, but it requires that all custom layers/models
      implement a `get_config()` method.

  Returns:
    None. Returns early without writing anything if `overwrite` is False, the
    path exists, and the user declines the overwrite prompt.

  Raises:
    ValueError: if the model's inputs have not been defined.
  """
  # If file exists and should not be overwritten.
  if not overwrite and os.path.exists(filepath):
    proceed = ask_to_proceed_with_overwrite(filepath)
    if not proceed:
      return

  if save_traces:
    if save_impl.should_skip_serialization(model):
      saving_utils.raise_model_input_error(model)

  if not include_optimizer:
    orig_optimizer = model.optimizer
    model.optimizer = None

  try:
    # Trace all functions and signatures with `training=0` instead of using an
    # already-set learning phase placeholder.
    # This is needed for compatibility reasons until learning phase setting
    # is removed from the public apis.
    with K.deprecated_internal_learning_phase_scope(0):
      # When saving a model involving batch norm layer within a strategy scope,
      # the replica context is not available when calling `add_update()`, and
      # thus we use the default replica context here.
      with distribution_strategy_context._get_default_replica_context():  # pylint: disable=protected-access
        with utils.keras_option_scope(save_traces):
          saved_nodes, node_paths = save_lib.save_and_return_nodes(
              model, filepath, signatures, options)

      # Save all metadata to a separate file in the SavedModel directory.
      metadata = generate_keras_metadata(saved_nodes, node_paths)

    with gfile.GFile(
        os.path.join(filepath, constants.SAVED_METADATA_PATH), "wb") as w:
      w.write(metadata.SerializeToString(deterministic=True))
  finally:
    # Restore the optimizer on every exit path. Without this, an exception
    # while tracing or writing would leave the caller's model with
    # `optimizer=None`.
    if not include_optimizer:
      model.optimizer = orig_optimizer


def generate_keras_metadata(saved_nodes, node_paths):
  """Constructs a KerasMetadata proto with the metadata of each keras object.

  Args:
    saved_nodes: Sequence of objects saved in the SavedModel, in node-id
      order (the position in this sequence is used as the proto `node_id`).
    node_paths: Mapping from each saved node to its path from the root object,
      given as a sequence of references exposing a `name` attribute (as
      returned by `save_lib.save_and_return_nodes` alongside `saved_nodes`).

  Returns:
    A `saved_metadata_pb2.SavedMetadata` proto with one entry per node that is
    a Keras `Layer`. Each entry's `node_path` is `"root"` for the root object,
    or `"root."` followed by the dot-joined reference names otherwise.
  """
  metadata = saved_metadata_pb2.SavedMetadata()

  for node_id, node in enumerate(saved_nodes):
    # Only Keras layers (which includes models) carry Keras metadata.
    if isinstance(node, base_layer.Layer):
      path = node_paths[node]
      if not path:
        # An empty path identifies the root object itself.
        node_path = "root"
      else:
        node_path = "root.{}".format(".".join(ref.name for ref in path))

      metadata.nodes.add(
          node_id=node_id,
          node_path=node_path,
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]),
          identifier=node._object_identifier,  # pylint: disable=protected-access
          metadata=node._tracking_metadata)  # pylint: disable=protected-access

  return metadata