# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utility functions to save/load keras Model to/from SavedModel."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import six

from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving import model_from_json
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.utils import mode_keys
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import model_utils
from tensorflow.python.saved_model import save as save_lib
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.experimental.export_saved_model')
def export_saved_model(model,
                       saved_model_path,
                       custom_objects=None,
                       as_text=False,
                       input_signature=None,
                       serving_only=False):
  """Exports a `tf.keras.Model` as a TensorFlow SavedModel.

  Note that at this time, subclassed models can only be saved using
  `serving_only=True`.

  The exported `SavedModel` is a standalone serialization of TensorFlow
  objects, and is supported by TF language APIs and the TensorFlow Serving
  system. To load the model, use the function
  `tf.keras.experimental.load_from_saved_model`.

  The `SavedModel` contains:

  1. a checkpoint containing the model weights.
  2. a `SavedModel` proto containing the TensorFlow backend graph. Separate
     graphs are saved for prediction (serving), train, and evaluation. If
     the model has not been compiled, then only the graph computing predictions
     will be exported.
  3. the model's JSON config. If the model is subclassed, this will only be
     included if the model's `get_config()` method is overridden.

  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  path = '/tmp/simple_keras_model'
  tf.keras.experimental.export_saved_model(model, path)

  # Load the saved keras model back.
  new_model = tf.keras.experimental.load_from_saved_model(path)
  new_model.summary()
  ```
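
  A subclassed model can only be exported with `serving_only=True`. A minimal
  sketch, assuming a hypothetical subclassed model `MyModel` that takes a
  single float32 input of shape `[None, 10]` and an illustrative save path:

  ```python
  import tensorflow as tf

  class MyModel(tf.keras.Model):

    def __init__(self):
      super(MyModel, self).__init__()
      self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs):
      return self.dense(inputs)

  model = MyModel()

  # Describe the expected inputs so the serving graph can be traced.
  signature = [tf.TensorSpec(shape=[None, 10], dtype=tf.float32)]
  tf.keras.experimental.export_saved_model(
      model, '/tmp/subclassed_keras_model',
      input_signature=signature, serving_only=True)
  ```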

  Args:
    model: A `tf.keras.Model` to be saved. If the model is subclassed, the flag
      `serving_only` must be set to True.
    saved_model_path: a string specifying the path to the SavedModel directory.
    custom_objects: Optional dictionary mapping string names to custom classes
      or functions (e.g. custom loss functions).
    as_text: bool, `False` by default. Whether to write the `SavedModel` proto
      in text format. Currently unavailable in serving-only mode.
    input_signature: A possibly nested sequence of `tf.TensorSpec` objects, used
      to specify the expected model inputs. See `tf.function` for more details.
    serving_only: bool, `False` by default. When this is true, only the
      prediction graph is saved.

  Raises:
    NotImplementedError: If the model is a subclassed model, and serving_only is
      False.
    ValueError: If the input signature cannot be inferred from the model.
    AssertionError: If the SavedModel directory already exists and isn't empty.
  """
  if serving_only:
    save_lib.save(
        model,
        saved_model_path,
        signatures=saving_utils.trace_model_call(model, input_signature))
  else:
    _save_v1_format(model, saved_model_path, custom_objects, as_text,
                    input_signature)

  try:
    _export_model_json(model, saved_model_path)
  except NotImplementedError:
    logging.warning('Skipped saving model JSON, subclassed model does not have '
                    'get_config() defined.')


def _export_model_json(model, saved_model_path):
  """Saves model configuration as a json string under assets folder."""
  model_json = model.to_json()
  model_json_filepath = os.path.join(
      saved_model_utils.get_or_create_assets_dir(saved_model_path),
      compat.as_text(constants.SAVED_MODEL_FILENAME_JSON))
  file_io.write_string_to_file(model_json_filepath, model_json)


def _export_model_variables(model, saved_model_path):
  """Saves model weights in checkpoint format under variables folder."""
  saved_model_utils.get_or_create_variables_dir(saved_model_path)
  checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path)
  model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
  return checkpoint_prefix


def _save_v1_format(model, path, custom_objects, as_text, input_signature):
  """Exports model to v1 SavedModel format."""
  from tensorflow.python.keras.engine import sequential  # pylint: disable=g-import-not-at-top

  if not model._is_graph_network:
    if isinstance(model, sequential.Sequential):
      # If input shape is not directly set in the model, the exported model
      # will infer the expected shapes of the input from the model.
      if not model.built and input_signature is None:
        raise ValueError(
            'Sequential model\'s input shape is unknown. Please build the '
            'model, or use the input_signature argument to specify the '
            'model inputs.')
    else:
      raise NotImplementedError(
          'Subclassed models can only be exported for serving. Please set '
          'argument serving_only=True.')

  builder = saved_model_builder._SavedModelBuilder(path)

  # Manually save variables to export them in an object-based checkpoint. This
  # skips the `builder.add_meta_graph_and_variables()` step, which saves a
  # name-based checkpoint.
  # TODO(b/113134168): Add fn to Builder to save with object-based saver.
  # TODO(b/113178242): This should only export the model json structure. Only
  # one save is needed once the weights can be copied from the model to the
  # clone.
  checkpoint_path = _export_model_variables(model, path)

  # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that
  # Keras models and `Estimator`s are exported with the same format.
  # Every time a mode is exported, the code checks to see if new variables have
  # been created (e.g. optimizer slot variables). If that is the case, the
  # checkpoint is re-saved to include the new variables.
  export_args = {'builder': builder,
                 'model': model,
                 'custom_objects': custom_objects,
                 'checkpoint_path': checkpoint_path,
                 'input_signature': input_signature}

  has_saved_vars = False
  if model.optimizer:
    if isinstance(model.optimizer, (optimizers.TFOptimizer,
                                    optimizer_v2.OptimizerV2)):
      _export_mode(mode_keys.ModeKeys.TRAIN, has_saved_vars, **export_args)
      has_saved_vars = True
      _export_mode(mode_keys.ModeKeys.TEST, has_saved_vars, **export_args)
    else:
      logging.warning(
          'Model was compiled with an optimizer, but the optimizer is not from '
          '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving '
          'graph was exported. The train and evaluate graphs were not added to '
          'the SavedModel.')
  _export_mode(mode_keys.ModeKeys.PREDICT, has_saved_vars, **export_args)

  builder.save(as_text)


def _get_var_list(model):
  """Returns list of all checkpointed saveable objects in the model."""
  var_list, _, _ = graph_view.ObjectGraphView(model).serialize_object_graph()
  return var_list


def create_placeholder(spec):
  """Creates a Keras placeholder matching the shape/dtype/name of a `TensorSpec`."""
  return K.placeholder(shape=spec.shape, dtype=spec.dtype, name=spec.name)


def _export_mode(
    mode, has_saved_vars, builder, model, custom_objects, checkpoint_path,
    input_signature):
  """Exports a model, and optionally saves new vars from the clone model.

  Args:
    mode: A `tf.estimator.ModeKeys` string.
    has_saved_vars: A `boolean` indicating whether the SavedModel has already
      exported variables.
    builder: A `SavedModelBuilder` object.
    model: A `tf.keras.Model` object.
    custom_objects: A dictionary mapping string names to custom classes
      or functions.
    checkpoint_path: String path to checkpoint.
    input_signature: Nested TensorSpec containing the expected inputs. Can be
      `None`, in which case the signature will be inferred from the model.

  Raises:
    ValueError: If the train/eval mode is being exported, but the model does
      not have an optimizer.
  """
  from tensorflow.python.keras import models as models_lib  # pylint: disable=g-import-not-at-top
  compile_clone = (mode != mode_keys.ModeKeys.PREDICT)
  if compile_clone and not model.optimizer:
    raise ValueError(
        'Model does not have an optimizer. Cannot export mode %s' % mode)

  model_graph = ops.get_default_graph()
  with ops.Graph().as_default() as g, K.learning_phase_scope(
      mode == mode_keys.ModeKeys.TRAIN):

    if input_signature is None:
      input_tensors = None
    else:
      input_tensors = nest.map_structure(create_placeholder, input_signature)

    # Clone the model into a blank graph. This will create placeholders for
    # inputs and targets.
    clone = models_lib.clone_and_build_model(
        model, input_tensors=input_tensors, custom_objects=custom_objects,
        compile_clone=compile_clone)

    # Make sure that the iterations variable is added to the global step
    # collection, so that, when the SavedModel graph is loaded, the iterations
    # variable is returned by `tf.train.get_global_step()`. This is required for
    # compatibility with the SavedModelEstimator.
    if compile_clone:
      g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)

    # Extract update and train ops from train/test/predict functions.
    train_op = None
    if mode == mode_keys.ModeKeys.TRAIN:
      clone._make_train_function()
      train_op = clone.train_function.updates_op
    elif mode == mode_keys.ModeKeys.TEST:
      clone._make_test_function()
    else:
      clone._make_predict_function()
    g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)

    with session.Session().as_default():
      clone_var_list = _get_var_list(clone)
      if has_saved_vars:
        # Confirm all variables in the clone have an entry in the checkpoint.
        status = clone.load_weights(checkpoint_path)
        status.assert_existing_objects_matched()
      else:
        # Confirm that variables between the clone and model match up exactly,
        # not counting optimizer objects. Optimizer objects are ignored because
        # if the model has not trained, the slot variables will not have been
        # created yet.
        # TODO(b/113179535): Replace with trackable equivalence.
        _assert_same_non_optimizer_objects(model, model_graph, clone, g)

        # TODO(b/113178242): Use value transfer for trackable objects.
        clone.load_weights(checkpoint_path)

        # Add graph and variables to SavedModel.
        # TODO(b/113134168): Switch to add_meta_graph_and_variables.
        clone.save_weights(checkpoint_path, save_format='tf', overwrite=True)
        builder._has_saved_variables = True

      # Add graph to the SavedModel builder.
      builder.add_meta_graph(
          model_utils.EXPORT_TAG_MAP[mode],
          signature_def_map=_create_signature_def_map(clone, mode),
          saver=saver_lib.Saver(clone_var_list),
          init_op=variables.local_variables_initializer(),
          train_op=train_op)
    return None


def _create_signature_def_map(model, mode):
  """Creates a SignatureDef map from a Keras model."""
  inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)}
  if model.optimizer:
    targets_dict = {x.name.split(':')[0]: x
                    for x in model.targets if x is not None}
    inputs_dict.update(targets_dict)
  outputs_dict = {name: x
                  for name, x in zip(model.output_names, model.outputs)}
  metrics = saving_utils.extract_model_metrics(model)

  # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables
  # are by default not added to any collections. We are doing this here, so
  # that metric variables get initialized.
  local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
  vars_to_add = set()
  if metrics is not None:
    from tensorflow.python.keras.metrics import Metric  # pylint: disable=g-import-not-at-top
    for key, value in six.iteritems(metrics):
      if isinstance(value, Metric):
        vars_to_add.update(value.variables)
        # Convert Metric instances to (value_tensor, update_op) tuple.
        metrics[key] = (value.result(), value.updates[0])
  # Remove variables that are in the local variables collection already.
  vars_to_add = vars_to_add.difference(local_vars)
  for v in vars_to_add:
    ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v)

  export_outputs = model_utils.export_outputs_for_mode(
      mode,
      predictions=outputs_dict,
      loss=model.total_loss if model.optimizer else None,
      metrics=metrics)
  return model_utils.build_all_signature_defs(
      inputs_dict,
      export_outputs=export_outputs,
      serving_only=(mode == mode_keys.ModeKeys.PREDICT))


def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph):  # pylint: disable=unused-argument
  """Asserts model and clone contain the same trackable objects."""

  # TODO(fchollet, kathywu): make sure this works in eager mode.
  return True


@keras_export('keras.experimental.load_from_saved_model')
def load_from_saved_model(saved_model_path, custom_objects=None):
  """Loads a keras Model from a SavedModel created by `export_saved_model()`.

  This function reinstantiates model state by:
  1) loading the model topology from JSON (this will eventually come
     from the metagraph).
  2) loading the model weights from the checkpoint.

  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  path = '/tmp/simple_keras_model'
  tf.keras.experimental.export_saved_model(model, path)

  # Load the saved keras model back.
  new_model = tf.keras.experimental.load_from_saved_model(path)
  new_model.summary()
  ```
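
  If the exported model used custom layers or other custom objects, pass them
  back at load time through `custom_objects`. A minimal sketch, assuming a
  hypothetical custom layer `Double` and an illustrative save path:

  ```python
  import tensorflow as tf

  # Hypothetical custom layer that the exported model was built with.
  class Double(tf.keras.layers.Layer):

    def call(self, inputs):
      return inputs * 2

  new_model = tf.keras.experimental.load_from_saved_model(
      '/tmp/model_with_custom_layer', custom_objects={'Double': Double})
  ```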

  Args:
    saved_model_path: a string specifying the path to an existing SavedModel.
    custom_objects: Optional dictionary mapping names
        (strings) to custom classes or functions to be
        considered during deserialization.

  Returns:
    a keras.Model instance.
  """
  # restore model topology from json string
  model_json_filepath = os.path.join(
      compat.as_bytes(saved_model_path),
      compat.as_bytes(constants.ASSETS_DIRECTORY),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
  model_json = file_io.read_file_to_string(model_json_filepath)
  model = model_from_json(model_json, custom_objects=custom_objects)

  # restore model weights
  checkpoint_prefix = os.path.join(
      compat.as_text(saved_model_path),
      compat.as_text(constants.VARIABLES_DIRECTORY),
      compat.as_text(constants.VARIABLES_FILENAME))
  model.load_weights(checkpoint_prefix)
  return model