• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for creating SavedModels."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import os
23import time
24
25import six
26
27from tensorflow.python.platform import gfile
28from tensorflow.python.platform import tf_logging as logging
29from tensorflow.python.saved_model import signature_constants
30from tensorflow.python.saved_model import signature_def_utils
31from tensorflow.python.saved_model import tag_constants
32from tensorflow.python.saved_model.model_utils import export_output as export_output_lib
33from tensorflow.python.saved_model.model_utils import mode_keys
34from tensorflow.python.saved_model.model_utils.mode_keys import KerasModeKeys as ModeKeys
35from tensorflow.python.util import compat
36
37
# MetaGraph tags attached to the SavedModel for each exported mode.
EXPORT_TAG_MAP = mode_keys.ModeKeyMap(**dict([
    (ModeKeys.PREDICT, [tag_constants.SERVING]),
    (ModeKeys.TRAIN, [tag_constants.TRAINING]),
    (ModeKeys.TEST, [tag_constants.EVAL]),
]))

# Each exported mode should get its own SignatureDef map, produced by
# `export_outputs_for_mode` followed by `build_all_signature_defs`. By default
# that map holds a single signature describing the input tensors together with
# the mode's outputs (predictions, losses, and/or metrics). The default
# SignatureDef key used for each mode is listed here.
SIGNATURE_KEY_MAP = mode_keys.ModeKeyMap(**dict([
    (ModeKeys.PREDICT, signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY),
    (ModeKeys.TRAIN, signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY),
    (ModeKeys.TEST, signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY),
]))

# Default names used in the SignatureDef input map (string -> TensorInfo).
# SINGLE_RECEIVER_DEFAULT_NAME is the key used when receiver tensors are given
# as a single tensor rather than a dict.
SINGLE_FEATURE_DEFAULT_NAME = 'feature'
SINGLE_RECEIVER_DEFAULT_NAME = 'input'
SINGLE_LABEL_DEFAULT_NAME = 'label'
59
60### Below utilities are specific to SavedModel exports.
61
62
def build_all_signature_defs(receiver_tensors,
                             export_outputs,
                             receiver_tensors_alternatives=None,
                             serving_only=True):
  """Build `SignatureDef`s for all export outputs.

  One signature is attempted for every entry of `export_outputs` using the
  default `receiver_tensors`. One more is attempted for every
  (alternative receiver, export output) pair from
  `receiver_tensors_alternatives`. A signature whose `as_signature_def` call
  raises `ValueError` is left out of the result, and the reason is logged.

  Args:
    receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying
      input nodes where this receiver expects to be fed by default.  Typically,
      this is a single placeholder expecting serialized `tf.Example` protos.
    export_outputs: a dict of ExportOutput instances, each of which has
      an as_signature_def instance method that will be called to retrieve
      the signature_def for all export output tensors.
    receiver_tensors_alternatives: a dict of string to additional
      groups of receiver tensors, each of which may be a `Tensor` or a dict of
      string to `Tensor`.  These named receiver tensor alternatives generate
      additional serving signatures, which may be used to feed inputs at
      different points within the input receiver subgraph.  A typical usage is
      to allow feeding raw feature `Tensor`s *downstream* of the
      tf.parse_example() op.  Defaults to None.
    serving_only: boolean; if true, resulting signature defs will only include
      valid serving signatures. If false, all requested signatures will be
      returned.

  Returns:
    A dict mapping signature names to `SignatureDef` protos. Signatures built
    from the default receiver are keyed by their export output key. Signatures
    built from an alternative receiver are keyed by
    '<receiver_name>:<output_key>'. A falsy name or key is rendered as 'None'.

  Raises:
    ValueError: if export_outputs is not a dict.
  """
  if not isinstance(receiver_tensors, dict):
    receiver_tensors = {SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
  # isinstance(None, dict) is False, so this also rejects None.
  if not isinstance(export_outputs, dict):
    raise ValueError('export_outputs must be a dict and not {}'.format(
        type(export_outputs)))

  signature_def_map = {}
  excluded_signatures = {}

  def _add_signatures(receivers, name_prefix):
    """Builds one signature per export output for the given receivers."""
    for output_key, export_output in export_outputs.items():
      signature_name = name_prefix + '{}'.format(output_key or 'None')
      try:
        signature_def_map[signature_name] = (
            export_output.as_signature_def(receivers))
      except ValueError as e:
        # Invalid combinations are reported in the log rather than failing
        # the whole export.
        excluded_signatures[signature_name] = str(e)

  _add_signatures(receiver_tensors, '')

  if receiver_tensors_alternatives:
    for receiver_name, receiver_tensors_alt in (
        receiver_tensors_alternatives.items()):
      if not isinstance(receiver_tensors_alt, dict):
        receiver_tensors_alt = {
            SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors_alt
        }
      _add_signatures(receiver_tensors_alt,
                      '{}:'.format(receiver_name or 'None'))

  _log_signature_report(signature_def_map, excluded_signatures)

  # The as_signature_def calls above should return only valid signatures. On a
  # validity problem they raise ValueError, and that signature is excluded
  # from signature_def_map. The is_valid_signature check confirms that the
  # remaining signatures are valid for serving, as an additional sanity check
  # for serving exports. It is skipped for training and eval signatures, which
  # are not intended for serving.
  if serving_only:
    signature_def_map = {
        k: v
        for k, v in signature_def_map.items()
        if signature_def_utils.is_valid_signature(v)
    }
  return signature_def_map
141
142
# Human-readable labels for the standard SignatureDef method names, used by
# _log_signature_report when listing the signatures in an export.
_FRIENDLY_METHOD_NAMES = dict([
    (signature_constants.CLASSIFY_METHOD_NAME, 'Classify'),
    (signature_constants.REGRESS_METHOD_NAME, 'Regress'),
    (signature_constants.PREDICT_METHOD_NAME, 'Predict'),
    (signature_constants.SUPERVISED_TRAIN_METHOD_NAME, 'Train'),
    (signature_constants.SUPERVISED_EVAL_METHOD_NAME, 'Eval'),
])
150
151
def _log_signature_report(signature_def_map, excluded_signatures):
  """Log a report of which signatures were included in, or excluded from, an export.

  Args:
    signature_def_map: dict mapping signature names to the SignatureDef
      protos that were produced; each value's `method_name` is used to group
      the names in the report.
    excluded_signatures: dict mapping signature names to the error message
      explaining why that signature was not produced.
  """
  sig_names_by_method_name = collections.defaultdict(list)

  # Seed every method name in _FRIENDLY_METHOD_NAMES so that a line is logged
  # for each of them even when no signature uses it. Any other method names
  # found in signature_def_map are collected as well.
  for method_name in _FRIENDLY_METHOD_NAMES:
    sig_names_by_method_name[method_name] = []

  for signature_name, sig in signature_def_map.items():
    sig_names_by_method_name[sig.method_name].append(signature_name)

  # TODO(b/67733540): consider printing the full signatures, not just names
  for method_name, sig_names in sig_names_by_method_name.items():
    logging.info('Signatures INCLUDED in export for %s: %s',
                 _FRIENDLY_METHOD_NAMES.get(method_name, method_name),
                 sig_names or 'None')

  if excluded_signatures:
    logging.info('Signatures EXCLUDED from export because they cannot be '
                 'served via TensorFlow Serving APIs:')
    for signature_name, message in excluded_signatures.items():
      logging.info('\'%s\' : %s', signature_name, message)

  if not signature_def_map:
    logging.warning('Export includes no signatures!')
  elif (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in
        signature_def_map):
    logging.warning('Export includes no default signature!')
183
184
# When we create a timestamped directory, there is a small chance that the
# directory already exists because another process is also creating these
# directories. In that case we wait one second to get a new timestamp and try
# again. If this fails MAX_DIRECTORY_CREATION_ATTEMPTS times in a row, then
# something is seriously wrong and we give up.
MAX_DIRECTORY_CREATION_ATTEMPTS = 10


def get_timestamped_export_dir(export_dir_base):
  """Builds a path to a new subdirectory within the base directory.

  Each export is written into a new subdirectory named using the
  current time.  This guarantees monotonically increasing version
  numbers even across multiple runs of the pipeline.
  The timestamp used is the number of seconds since epoch UTC.

  Args:
    export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.

  Returns:
    The full path of the new subdirectory (which is not actually created yet),
    as `bytes`.

  Raises:
    RuntimeError: if repeated attempts fail to obtain a unique timestamped
      directory name.
  """
  for attempt in range(1, MAX_DIRECTORY_CREATION_ATTEMPTS + 1):
    timestamp = int(time.time())

    result_dir = os.path.join(
        compat.as_bytes(export_dir_base), compat.as_bytes(str(timestamp)))
    if not gfile.Exists(result_dir):
      # Collisions are still possible (though extremely unlikely): this
      # directory is not actually created yet, but it will be almost
      # instantly on return from this function.
      return result_dir
    if attempt == MAX_DIRECTORY_CREATION_ATTEMPTS:
      # Out of attempts: do not sleep or announce a retry that won't happen.
      break
    logging.warning('Directory %s already exists; retrying (attempt %d/%d)',
                    result_dir, attempt, MAX_DIRECTORY_CREATION_ATTEMPTS)
    # Wait so that the next attempt sees a new one-second timestamp.
    time.sleep(1)
  raise RuntimeError('Failed to obtain a unique export directory name after '
                     '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS))
228
229
def get_temp_export_dir(timestamped_export_dir):
  """Builds a directory name based on the argument but starting with 'temp-'.

  This relies on the fact that TensorFlow Serving ignores subdirectories of
  the base directory that can't be parsed as integers.

  Args:
    timestamped_export_dir: the name of the eventual export directory, e.g.
      /foo/bar/<timestamp>, as `str` or `bytes` (get_timestamped_export_dir
      returns `bytes`).

  Returns:
    A sister directory prefixed with 'temp-', e.g. /foo/bar/temp-<timestamp>,
    as `bytes`.
  """
  (dirname, basename) = os.path.split(timestamped_export_dir)
  # Concatenate as bytes: formatting a bytes basename into a str template
  # (e.g. 'temp-{}'.format(b'123')) produces "temp-b'123'" under Python 3.
  temp_basename = b'temp-' + compat.as_bytes(basename)
  temp_export_dir = os.path.join(compat.as_bytes(dirname), temp_basename)
  return temp_export_dir
247
248
def export_outputs_for_mode(
    mode, serving_export_outputs=None, predictions=None, loss=None,
    metrics=None):
  """Util function for constructing an `ExportOutput` dict given a mode.

  The returned dict can be directly passed to `build_all_signature_defs` helper
  function as the `export_outputs` argument, used for generating a SignatureDef
  map.

  Args:
    mode: A `ModeKeys` specifying the mode.
    serving_export_outputs: Describes the output signatures to be exported to
      `SavedModel` and used during serving. Should be a dict or None. Only
      used in predict mode.
    predictions: A dict of Tensors or single Tensor representing model
      predictions. In predict mode this argument is only used if
      serving_export_outputs is not set.
    loss: A dict of Tensors or single Tensor representing calculated loss.
    metrics: A dict of (metric_value, update_op) tuples, or a single tuple.
      metric_value must be a Tensor, and update_op must be a Tensor or Op.

  Returns:
    Dictionary mapping a key to an `ExportOutput` object. In train and eval
    modes the single key is the mode's default SignatureDef key from
    SIGNATURE_KEY_MAP. In predict mode the dict is the one validated (or
    built) by `get_export_outputs`.

  Raises:
    ValueError: if an appropriate ExportOutput cannot be found for the mode,
      or (predict mode) if several serving outputs lack a default key.
    TypeError: (predict mode) if serving_export_outputs is not a dict of
      ExportOutput instances.
  """
  if mode not in SIGNATURE_KEY_MAP:
    raise ValueError(
        'Export output type not found for mode: {}. Expected one of: {}.\n'
        'One likely error is that V1 Estimator Modekeys were somehow passed to '
        'this function. Please ensure that you are using the new ModeKeys.'
        # list(...) so the message shows the actual keys, not a view object.
        .format(mode, list(SIGNATURE_KEY_MAP.keys())))
  if mode_keys.is_predict(mode):
    return get_export_outputs(serving_export_outputs, predictions)
  signature_key = SIGNATURE_KEY_MAP[mode]
  if mode_keys.is_train(mode):
    return {signature_key: export_output_lib.TrainOutput(
        loss=loss, predictions=predictions, metrics=metrics)}
  return {signature_key: export_output_lib.EvalOutput(
      loss=loss, predictions=predictions, metrics=metrics)}
291
292
def get_export_outputs(export_outputs, predictions):
  """Validates `export_outputs`, or builds a default one from `predictions`.

  Args:
    export_outputs: dict describing the output signatures to be exported to
      the `SavedModel` and used during serving, or None to derive a single
      default `PredictOutput` signature from `predictions`.
    predictions: Predictions `Tensor` or dict of `Tensor`; only consulted when
      `export_outputs` is None.

  Returns:
    The validated export_outputs dict, possibly extended in place with a
    default serving entry by `_maybe_add_default_serving_output`.

  Raises:
    TypeError: if export_outputs is not a dict or its values are not
      ExportOutput instances.
    ValueError: raised by `_maybe_add_default_serving_output` when several
      outputs are given and none uses the default serving key.
  """
  outputs = export_outputs
  if outputs is None:
    predict_output = export_output_lib.PredictOutput(predictions)
    outputs = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: predict_output,
    }

  if not isinstance(outputs, dict):
    raise TypeError(
        'export_outputs must be dict, given: {}'.format(outputs))

  has_bad_value = any(
      not isinstance(output, export_output_lib.ExportOutput)
      for output in outputs.values())
  if has_bad_value:
    raise TypeError(
        'Values in export_outputs must be ExportOutput objects. '
        'Given: {}'.format(outputs))

  _maybe_add_default_serving_output(outputs)
  return outputs
325
326
327def _maybe_add_default_serving_output(export_outputs):
328  """Add a default serving output to the export_outputs if not present.
329
330  Args:
331    export_outputs: Describes the output signatures to be exported to
332      `SavedModel` and used during serving. Should be a dict.
333
334  Returns:
335    export_outputs dict with default serving signature added if necessary
336
337  Raises:
338    ValueError: if multiple export_outputs were provided without a default
339      serving key.
340  """
341  if len(export_outputs) == 1:
342    (key, value), = export_outputs.items()
343    if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
344      export_outputs[
345          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value
346  if len(export_outputs) > 1:
347    if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
348        not in export_outputs):
349      raise ValueError(
350          'Multiple export_outputs were provided, but none of them is '
351          'specified as the default.  Do this by naming one of them with '
352          'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.')
353
354  return export_outputs
355