1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utils related to keras model saving.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections.abc as collections_abc 21import copy 22import os 23import six 24 25from tensorflow.python.eager import def_function 26from tensorflow.python.keras import backend as K 27from tensorflow.python.keras import losses 28from tensorflow.python.keras import optimizer_v1 29from tensorflow.python.keras import optimizers 30from tensorflow.python.keras.engine import base_layer_utils 31from tensorflow.python.keras.utils import generic_utils 32from tensorflow.python.keras.utils import version_utils 33from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite 34from tensorflow.python.platform import tf_logging as logging 35from tensorflow.python.util import nest 36 37 38def extract_model_metrics(model): 39 """Convert metrics from a Keras model `compile` API to dictionary. 40 41 This is used for converting Keras models to Estimators and SavedModels. 42 43 Args: 44 model: A `tf.keras.Model` object. 45 46 Returns: 47 Dictionary mapping metric names to metric instances. May return `None` if 48 the model does not contain any metrics. 49 """ 50 if getattr(model, '_compile_metrics', None): 51 # TODO(psv/kathywu): use this implementation in model to estimator flow. 52 # We are not using model.metrics here because we want to exclude the metrics 53 # added using `add_metric` API. 54 return {m.name: m for m in model._compile_metric_functions} # pylint: disable=protected-access 55 return None 56 57 58def model_input_signature(model, keep_original_batch_size=False): 59 """Inspect model to get its input signature. 60 61 The model's input signature is a list with a single (possibly-nested) object. 62 This is due to the Keras-enforced restriction that tensor inputs must be 63 passed in as the first argument. 64 65 For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>} 66 will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}] 67 68 Args: 69 model: Keras Model object. 70 keep_original_batch_size: A boolean indicating whether we want to keep using 71 the original batch size or set it to None. Default is `False`, which means 72 that the batch dim of the returned input signature will always be set to 73 `None`. 74 75 Returns: 76 A list containing either a single TensorSpec or an object with nested 77 TensorSpecs. This list does not contain the `training` argument. 78 """ 79 input_specs = model._get_save_spec(dynamic_batch=not keep_original_batch_size) # pylint: disable=protected-access 80 if input_specs is None: 81 return None 82 input_specs = _enforce_names_consistency(input_specs) 83 # Return a list with a single element as the model's input signature. 84 if isinstance(input_specs, 85 collections_abc.Sequence) and len(input_specs) == 1: 86 # Note that the isinstance check filters out single-element dictionaries, 87 # which should also be wrapped as a single-element list. 88 return input_specs 89 else: 90 return [input_specs] 91 92 93def raise_model_input_error(model): 94 raise ValueError( 95 'Model {} cannot be saved because the input shapes have not been ' 96 'set. Usually, input shapes are automatically determined from calling' 97 ' `.fit()` or `.predict()`. To manually set the shapes, call ' 98 '`model.build(input_shape)`.'.format(model)) 99 100 101def trace_model_call(model, input_signature=None): 102 """Trace the model call to create a tf.function for exporting a Keras model. 103 104 Args: 105 model: A Keras model. 106 input_signature: optional, a list of tf.TensorSpec objects specifying the 107 inputs to the model. 108 109 Returns: 110 A tf.function wrapping the model's call function with input signatures set. 111 112 Raises: 113 ValueError: if input signature cannot be inferred from the model. 114 """ 115 if input_signature is None: 116 if isinstance(model.call, def_function.Function): 117 input_signature = model.call.input_signature 118 119 if input_signature is None: 120 input_signature = model_input_signature(model) 121 122 if input_signature is None: 123 raise_model_input_error(model) 124 125 @def_function.function(input_signature=input_signature) 126 def _wrapped_model(*args): 127 """A concrete tf.function that wraps the model's call function.""" 128 # When given a single input, Keras models will call the model on the tensor 129 # rather than a list consisting of the single tensor. 130 inputs = args[0] if len(input_signature) == 1 else list(args) 131 132 with base_layer_utils.call_context().enter( 133 model, inputs=inputs, build_graph=False, training=False, saving=True): 134 outputs = model(inputs, training=False) 135 136 # Outputs always has to be a flat dict. 137 output_names = model.output_names # Functional Model. 138 if output_names is None: # Subclassed Model. 139 from tensorflow.python.keras.engine import compile_utils # pylint: disable=g-import-not-at-top 140 output_names = compile_utils.create_pseudo_output_names(outputs) 141 outputs = nest.flatten(outputs) 142 return {name: output for name, output in zip(output_names, outputs)} 143 144 return _wrapped_model 145 146 147def model_metadata(model, include_optimizer=True, require_config=True): 148 """Returns a dictionary containing the model metadata.""" 149 from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top 150 from tensorflow.python.keras.optimizer_v2 import optimizer_v2 # pylint: disable=g-import-not-at-top 151 152 model_config = {'class_name': model.__class__.__name__} 153 try: 154 model_config['config'] = model.get_config() 155 except NotImplementedError as e: 156 if require_config: 157 raise e 158 159 metadata = dict( 160 keras_version=str(keras_version), 161 backend=K.backend(), 162 model_config=model_config) 163 if model.optimizer and include_optimizer: 164 if isinstance(model.optimizer, optimizer_v1.TFOptimizer): 165 logging.warning( 166 'TensorFlow optimizers do not ' 167 'make it possible to access ' 168 'optimizer attributes or optimizer state ' 169 'after instantiation. ' 170 'As a result, we cannot save the optimizer ' 171 'as part of the model save file. ' 172 'You will have to compile your model again after loading it. ' 173 'Prefer using a Keras optimizer instead ' 174 '(see keras.io/optimizers).') 175 elif model._compile_was_called: # pylint: disable=protected-access 176 training_config = model._get_compile_args(user_metrics=False) # pylint: disable=protected-access 177 training_config.pop('optimizer', None) # Handled separately. 178 metadata['training_config'] = _serialize_nested_config(training_config) 179 if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer): 180 raise NotImplementedError( 181 'As of now, Optimizers loaded from SavedModel cannot be saved. ' 182 'If you\'re calling `model.save` or `tf.keras.models.save_model`,' 183 ' please set the `include_optimizer` option to `False`. For ' 184 '`tf.saved_model.save`, delete the optimizer from the model.') 185 else: 186 optimizer_config = { 187 'class_name': 188 generic_utils.get_registered_name(model.optimizer.__class__), 189 'config': 190 model.optimizer.get_config() 191 } 192 metadata['training_config']['optimizer_config'] = optimizer_config 193 return metadata 194 195 196def should_overwrite(filepath, overwrite): 197 """Returns whether the filepath should be overwritten.""" 198 # If file exists and should not be overwritten. 199 if not overwrite and os.path.isfile(filepath): 200 return ask_to_proceed_with_overwrite(filepath) 201 return True 202 203 204def compile_args_from_training_config(training_config, custom_objects=None): 205 """Return model.compile arguments from training config.""" 206 if custom_objects is None: 207 custom_objects = {} 208 209 with generic_utils.CustomObjectScope(custom_objects): 210 optimizer_config = training_config['optimizer_config'] 211 optimizer = optimizers.deserialize(optimizer_config) 212 213 # Recover losses. 214 loss = None 215 loss_config = training_config.get('loss', None) 216 if loss_config is not None: 217 loss = _deserialize_nested_config(losses.deserialize, loss_config) 218 219 # Recover metrics. 220 metrics = None 221 metrics_config = training_config.get('metrics', None) 222 if metrics_config is not None: 223 metrics = _deserialize_nested_config(_deserialize_metric, metrics_config) 224 225 # Recover weighted metrics. 226 weighted_metrics = None 227 weighted_metrics_config = training_config.get('weighted_metrics', None) 228 if weighted_metrics_config is not None: 229 weighted_metrics = _deserialize_nested_config(_deserialize_metric, 230 weighted_metrics_config) 231 232 sample_weight_mode = training_config['sample_weight_mode'] if hasattr( 233 training_config, 'sample_weight_mode') else None 234 loss_weights = training_config['loss_weights'] 235 236 return dict( 237 optimizer=optimizer, 238 loss=loss, 239 metrics=metrics, 240 weighted_metrics=weighted_metrics, 241 loss_weights=loss_weights, 242 sample_weight_mode=sample_weight_mode) 243 244 245def _deserialize_nested_config(deserialize_fn, config): 246 """Deserializes arbitrary Keras `config` using `deserialize_fn`.""" 247 248 def _is_single_object(obj): 249 if isinstance(obj, dict) and 'class_name' in obj: 250 return True # Serialized Keras object. 251 if isinstance(obj, six.string_types): 252 return True # Serialized function or string. 253 return False 254 255 if config is None: 256 return None 257 if _is_single_object(config): 258 return deserialize_fn(config) 259 elif isinstance(config, dict): 260 return { 261 k: _deserialize_nested_config(deserialize_fn, v) 262 for k, v in config.items() 263 } 264 elif isinstance(config, (tuple, list)): 265 return [_deserialize_nested_config(deserialize_fn, obj) for obj in config] 266 267 raise ValueError('Saved configuration not understood.') 268 269 270def _serialize_nested_config(config): 271 """Serialized a nested structure of Keras objects.""" 272 273 def _serialize_fn(obj): 274 if callable(obj): 275 return generic_utils.serialize_keras_object(obj) 276 return obj 277 278 return nest.map_structure(_serialize_fn, config) 279 280 281def _deserialize_metric(metric_config): 282 """Deserialize metrics, leaving special strings untouched.""" 283 from tensorflow.python.keras import metrics as metrics_module # pylint:disable=g-import-not-at-top 284 if metric_config in ['accuracy', 'acc', 'crossentropy', 'ce']: 285 # Do not deserialize accuracy and cross-entropy strings as we have special 286 # case handling for these in compile, based on model output shape. 287 return metric_config 288 return metrics_module.deserialize(metric_config) 289 290 291def _enforce_names_consistency(specs): 292 """Enforces that either all specs have names or none do.""" 293 294 def _has_name(spec): 295 return hasattr(spec, 'name') and spec.name is not None 296 297 def _clear_name(spec): 298 spec = copy.deepcopy(spec) 299 if hasattr(spec, 'name'): 300 spec._name = None # pylint:disable=protected-access 301 return spec 302 303 flat_specs = nest.flatten(specs) 304 name_inconsistency = ( 305 any(_has_name(s) for s in flat_specs) and 306 not all(_has_name(s) for s in flat_specs)) 307 308 if name_inconsistency: 309 specs = nest.map_structure(_clear_name, specs) 310 return specs 311 312 313def try_build_compiled_arguments(model): 314 if (not version_utils.is_v1_layer_or_model(model) and 315 model.outputs is not None): 316 try: 317 model.compiled_loss.build(model.outputs) 318 model.compiled_metrics.build(model.outputs, model.outputs) 319 except: # pylint: disable=bare-except 320 logging.warning( 321 'Compiled the loaded model, but the compiled metrics have yet to ' 322 'be built. `model.compile_metrics` will be empty until you train ' 323 'or evaluate the model.') 324 325 326def is_hdf5_filepath(filepath): 327 return (filepath.endswith('.h5') or filepath.endswith('.keras') or 328 filepath.endswith('.hdf5')) 329