# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Deprecated experimental Keras SavedModel implementation."""

import os
import warnings

from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras import optimizer_v1
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving import model_config
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving import utils_v1 as model_utils
from tensorflow.python.keras.utils import mode_keys
from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import save as save_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
metrics_lib = LazyLoader("metrics_lib", globals(),
                         "tensorflow.python.keras.metrics")
models_lib = LazyLoader("models_lib", globals(),
                        "tensorflow.python.keras.models")
sequential = LazyLoader(
    "sequential", globals(),
    "tensorflow.python.keras.engine.sequential")
# pylint:enable=g-inconsistent-quotes


# File name for json format of SavedModel.
SAVED_MODEL_FILENAME_JSON = 'saved_model.json'


@keras_export(v1=['keras.experimental.export_saved_model'])
def export_saved_model(model,
                       saved_model_path,
                       custom_objects=None,
                       as_text=False,
                       input_signature=None,
                       serving_only=False):
  """Exports a `tf.keras.Model` as a TensorFlow SavedModel.

  Note that at this time, subclassed models can only be saved using
  `serving_only=True`.

  The exported `SavedModel` is a standalone serialization of TensorFlow objects,
  and is supported by TF language APIs and the TensorFlow Serving system.
  To load the model, use the function
  `tf.keras.experimental.load_from_saved_model`.

  The `SavedModel` contains:

  1. a checkpoint containing the model weights.
  2. a `SavedModel` proto containing the TensorFlow backend graph. Separate
     graphs are saved for prediction (serving), training, and evaluation. If
     the model has not been compiled, then only the graph computing predictions
     will be exported.
  3. the model's json config. If the model is subclassed, this will only be
     included if the model's `get_config()` method is overridden.

  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  path = '/tmp/simple_keras_model'
  tf.keras.experimental.export_saved_model(model, path)

  # Load the saved keras model back.
  new_model = tf.keras.experimental.load_from_saved_model(path)
  new_model.summary()
  ```
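
  Subclassed models can be exported for serving by passing an explicit
  `input_signature` and setting `serving_only=True`. For illustration
  (`subclassed_model` and the `TensorSpec` shape below are placeholders):

  ```python
  # `subclassed_model` stands for any subclassed tf.keras.Model; the
  # TensorSpec shape is illustrative.
  spec = tf.TensorSpec(shape=[None, 10], dtype=tf.float32, name='input')
  tf.keras.experimental.export_saved_model(
      subclassed_model, '/tmp/subclassed_model',
      input_signature=[spec], serving_only=True)
  ```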

  Args:
    model: A `tf.keras.Model` to be saved. If the model is subclassed, the flag
      `serving_only` must be set to True.
    saved_model_path: a string specifying the path to the SavedModel directory.
    custom_objects: Optional dictionary mapping string names to custom classes
      or functions (e.g. custom loss functions).
    as_text: bool, `False` by default. Whether to write the `SavedModel` proto
      in text format. Currently unavailable in serving-only mode.
    input_signature: A possibly nested sequence of `tf.TensorSpec` objects, used
      to specify the expected model inputs. See `tf.function` for more details.
    serving_only: bool, `False` by default. When this is true, only the
      prediction graph is saved.

  Raises:
    NotImplementedError: If the model is a subclassed model, and serving_only is
      False.
    ValueError: If the input signature cannot be inferred from the model.
    AssertionError: If the SavedModel directory already exists and isn't empty.
  """
  warnings.warn('`tf.keras.experimental.export_saved_model` is deprecated '
                'and will be removed in a future version. '
                'Please use `model.save(..., save_format="tf")` or '
                '`tf.keras.models.save_model(..., save_format="tf")`.')
  if serving_only:
    save_lib.save(
        model,
        saved_model_path,
        signatures=saving_utils.trace_model_call(model, input_signature))
  else:
    _save_v1_format(model, saved_model_path, custom_objects, as_text,
                    input_signature)

  try:
    _export_model_json(model, saved_model_path)
  except NotImplementedError:
    logging.warning('Skipped saving model JSON, subclassed model does not have '
                    'get_config() defined.')


def _export_model_json(model, saved_model_path):
  """Saves model configuration as a json string under the assets folder."""
  model_json = model.to_json()
  model_json_filepath = os.path.join(
      _get_or_create_assets_dir(saved_model_path),
      compat.as_text(SAVED_MODEL_FILENAME_JSON))
  with gfile.Open(model_json_filepath, 'w') as f:
    f.write(model_json)


def _export_model_variables(model, saved_model_path):
  """Saves model weights in checkpoint format under the variables folder."""
  _get_or_create_variables_dir(saved_model_path)
  checkpoint_prefix = _get_variables_path(saved_model_path)
  model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
  return checkpoint_prefix


def _save_v1_format(model, path, custom_objects, as_text, input_signature):
  """Exports model to v1 SavedModel format."""
  if not model._is_graph_network:  # pylint: disable=protected-access
    if isinstance(model, sequential.Sequential):
      # If input shape is not directly set in the model, the exported model
      # will infer the expected shapes of the input from the model.
      if not model.built:
        raise ValueError('Weights for sequential model have not yet been '
                         'created. Weights are created when the Model is first '
                         'called on inputs or `build()` is called with an '
                         '`input_shape`, or the first layer in the model has '
                         '`input_shape` during construction.')
      # TODO(kathywu): Build the model with input_signature to create the
      # weights before _export_model_variables().
    else:
      raise NotImplementedError(
          'Subclassed models can only be exported for serving. Please set '
          'argument serving_only=True.')

  builder = saved_model_builder._SavedModelBuilder(path)  # pylint: disable=protected-access

  # Manually save variables to export them in an object-based checkpoint. This
  # skips the `builder.add_meta_graph_and_variables()` step, which saves a
  # name-based checkpoint.
  # TODO(b/113134168): Add fn to Builder to save with object-based saver.
  # TODO(b/113178242): This should only export the model json structure. Only
  # one save is needed once the weights can be copied from the model to clone.
  checkpoint_path = _export_model_variables(model, path)

  # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that
  # Keras models and `Estimator`s are exported with the same format.
  # Every time a mode is exported, the code checks to see if new variables have
  # been created (e.g. optimizer slot variables). If that is the case, the
  # checkpoint is re-saved to include the new variables.
  export_args = {'builder': builder,
                 'model': model,
                 'custom_objects': custom_objects,
                 'checkpoint_path': checkpoint_path,
                 'input_signature': input_signature}

  has_saved_vars = False
  if model.optimizer:
    if isinstance(model.optimizer, (optimizer_v1.TFOptimizer,
                                    optimizer_v2.OptimizerV2)):
      _export_mode(mode_keys.ModeKeys.TRAIN, has_saved_vars, **export_args)
      has_saved_vars = True
      _export_mode(mode_keys.ModeKeys.TEST, has_saved_vars, **export_args)
    else:
      logging.warning(
          'Model was compiled with an optimizer, but the optimizer is not from '
          '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving '
          'graph was exported. The train and evaluate graphs were not added to '
          'the SavedModel.')
  _export_mode(mode_keys.ModeKeys.PREDICT, has_saved_vars, **export_args)

  builder.save(as_text)


def _get_var_list(model):
  """Returns a list of all checkpointed saveable objects in the model."""
  var_list, _, _ = graph_view.ObjectGraphView(model).serialize_object_graph()
  return var_list


def create_placeholder(spec):
  """Creates a Keras backend placeholder matching the given `TensorSpec`."""
  return backend.placeholder(shape=spec.shape, dtype=spec.dtype, name=spec.name)


def _export_mode(
    mode, has_saved_vars, builder, model, custom_objects, checkpoint_path,
    input_signature):
  """Exports a model, and optionally saves new vars from the clone model.

  Args:
    mode: A `tf.estimator.ModeKeys` string.
    has_saved_vars: A `boolean` indicating whether the SavedModel has already
      exported variables.
    builder: A `SavedModelBuilder` object.
    model: A `tf.keras.Model` object.
    custom_objects: A dictionary mapping string names to custom classes
      or functions.
    checkpoint_path: String path to checkpoint.
    input_signature: Nested TensorSpec containing the expected inputs. Can be
      `None`, in which case the signature will be inferred from the model.

  Raises:
    ValueError: If the train/eval mode is being exported, but the model does
      not have an optimizer.
  """
  compile_clone = (mode != mode_keys.ModeKeys.PREDICT)
  if compile_clone and not model.optimizer:
    raise ValueError(
        'Model does not have an optimizer. Cannot export mode %s' % mode)

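  # Graph that contains the original model; it is passed to the clone/model
  # consistency check further below.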
  model_graph = ops.get_default_graph()
  with ops.Graph().as_default() as g, backend.learning_phase_scope(
      mode == mode_keys.ModeKeys.TRAIN):

    if input_signature is None:
      input_tensors = None
    else:
      input_tensors = nest.map_structure(create_placeholder, input_signature)

    # Clone the model into a blank graph. This will create placeholders for
    # inputs and targets.
    clone = models_lib.clone_and_build_model(
        model, input_tensors=input_tensors, custom_objects=custom_objects,
        compile_clone=compile_clone)

    # Make sure that the iterations variable is added to the global step
    # collection, to ensure that, when the SavedModel graph is loaded, the
    # iterations variable is returned by
    # `tf.compat.v1.train.get_global_step()`. This is required for
    # compatibility with the SavedModelEstimator.
    if compile_clone:
      g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)

    # Extract update and train ops from train/test/predict functions.
    train_op = None
    if mode == mode_keys.ModeKeys.TRAIN:
      clone._make_train_function()  # pylint: disable=protected-access
      train_op = clone.train_function.updates_op
    elif mode == mode_keys.ModeKeys.TEST:
      clone._make_test_function()  # pylint: disable=protected-access
    else:
      clone._make_predict_function()  # pylint: disable=protected-access
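    # Expose the clone's state updates to consumers of the exported graph via
    # the standard UPDATE_OPS collection.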
    g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)

    with session.Session().as_default():
      clone_var_list = _get_var_list(clone)
      if has_saved_vars:
        # Confirm all variables in the clone have an entry in the checkpoint.
        status = clone.load_weights(checkpoint_path)
        status.assert_existing_objects_matched()
      else:
        # Confirm that variables between the clone and model match up exactly,
        # not counting optimizer objects. Optimizer objects are ignored because
        # if the model has not trained, the slot variables will not have been
        # created yet.
        # TODO(b/113179535): Replace with trackable equivalence.
        _assert_same_non_optimizer_objects(model, model_graph, clone, g)

        # TODO(b/113178242): Use value transfer for trackable objects.
        clone.load_weights(checkpoint_path)

        # Add graph and variables to SavedModel.
        # TODO(b/113134168): Switch to add_meta_graph_and_variables.
        clone.save_weights(checkpoint_path, save_format='tf', overwrite=True)
        builder._has_saved_variables = True  # pylint: disable=protected-access

      # Add graph to the SavedModel builder.
      builder.add_meta_graph(
          model_utils.EXPORT_TAG_MAP[mode],
          signature_def_map=_create_signature_def_map(clone, mode),
          saver=saver_lib.Saver(
              clone_var_list,
              # Allow saving Models with no variables. This is somewhat odd, but
              # it's not necessarily a bug.
              allow_empty=True),
          init_op=variables.local_variables_initializer(),
          train_op=train_op)
    return None


def _create_signature_def_map(model, mode):
  """Creates a SignatureDef map from a Keras model."""
  inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)}
  if model.optimizer:
    targets_dict = {x.name.split(':')[0]: x
                    for x in model._targets if x is not None}  # pylint: disable=protected-access
    inputs_dict.update(targets_dict)
  outputs_dict = {name: x
                  for name, x in zip(model.output_names, model.outputs)}
  metrics = saving_utils.extract_model_metrics(model)

  # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables
  # are by default not added to any collections. We are doing this here, so
  # that metric variables get initialized.
  local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
  vars_to_add = set()
  if metrics is not None:
    for key, value in metrics.items():
      if isinstance(value, metrics_lib.Metric):
        vars_to_add.update(value.variables)
        # Convert Metric instances to (value_tensor, update_op) tuple.
        metrics[key] = (value.result(), value.updates[0])
  # Remove variables that are in the local variables collection already.
  vars_to_add = vars_to_add.difference(local_vars)
  for v in vars_to_add:
    ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v)

  export_outputs = model_utils.export_outputs_for_mode(
      mode,
      predictions=outputs_dict,
      loss=model.total_loss if model.optimizer else None,
      metrics=metrics)
  return model_utils.build_all_signature_defs(
      inputs_dict,
      export_outputs=export_outputs,
      serving_only=(mode == mode_keys.ModeKeys.PREDICT))


def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph):  # pylint: disable=unused-argument
  """Asserts model and clone contain the same trackable objects.

  Currently a no-op that always returns `True` (see the TODO below).
  """

  # TODO(fchollet, kathywu): make sure this works in eager mode.
  return True


@keras_export(v1=['keras.experimental.load_from_saved_model'])
def load_from_saved_model(saved_model_path, custom_objects=None):
  """Loads a keras Model from a SavedModel created by `export_saved_model()`.

  This function reinstantiates model state by:
  1) loading model topology from json (this will eventually come
     from metagraph).
  2) loading model weights from checkpoint.

  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  path = '/tmp/simple_keras_model'
  tf.keras.experimental.export_saved_model(model, path)

  # Load the saved keras model back.
  new_model = tf.keras.experimental.load_from_saved_model(path)
  new_model.summary()
  ```
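
  If the model was saved with custom layers or other custom objects, pass them
  through `custom_objects` when loading (`my_custom_loss` below is a
  placeholder for your own function):

  ```python
  new_model = tf.keras.experimental.load_from_saved_model(
      path, custom_objects={'my_custom_loss': my_custom_loss})
  ```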

  Args:
    saved_model_path: a string specifying the path to an existing SavedModel.
    custom_objects: Optional dictionary mapping names
        (strings) to custom classes or functions to be
        considered during deserialization.

  Returns:
    a keras.Model instance.
  """
  warnings.warn('`tf.keras.experimental.load_from_saved_model` is deprecated '
                'and will be removed in a future version. '
                'Please switch to `tf.keras.models.load_model`.')
  # restore model topology from json string
  model_json_filepath = os.path.join(
      compat.as_bytes(saved_model_path),
      compat.as_bytes(constants.ASSETS_DIRECTORY),
      compat.as_bytes(SAVED_MODEL_FILENAME_JSON))
  with gfile.Open(model_json_filepath, 'r') as f:
    model_json = f.read()
  model = model_config.model_from_json(
      model_json, custom_objects=custom_objects)

  # restore model weights
  checkpoint_prefix = os.path.join(
      compat.as_text(saved_model_path),
      compat.as_text(constants.VARIABLES_DIRECTORY),
      compat.as_text(constants.VARIABLES_FILENAME))
  model.load_weights(checkpoint_prefix)
  return model


#### Directory / path helpers


def _get_or_create_variables_dir(export_dir):
  """Return variables sub-directory, or create one if it doesn't exist."""
  variables_dir = _get_variables_dir(export_dir)
  file_io.recursive_create_dir(variables_dir)
  return variables_dir


def _get_variables_dir(export_dir):
  """Return variables sub-directory in the SavedModel."""
  return os.path.join(
      compat.as_text(export_dir),
      compat.as_text(constants.VARIABLES_DIRECTORY))


def _get_variables_path(export_dir):
  """Return the variables path, used as the prefix for checkpoint files."""
  return os.path.join(
      compat.as_text(_get_variables_dir(export_dir)),
      compat.as_text(constants.VARIABLES_FILENAME))


def _get_or_create_assets_dir(export_dir):
  """Return assets sub-directory, or create one if it doesn't exist."""
  assets_destination_dir = _get_assets_dir(export_dir)

  file_io.recursive_create_dir(assets_destination_dir)

  return assets_destination_dir


def _get_assets_dir(export_dir):
  """Return path to asset directory in the SavedModel."""
  return os.path.join(
      compat.as_text(export_dir),
      compat.as_text(constants.ASSETS_DIRECTORY))