• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""SignatureDef utility functions implementation."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21
22from tensorflow.core.framework import types_pb2
23from tensorflow.core.protobuf import meta_graph_pb2
24from tensorflow.python.framework import errors
25from tensorflow.python.framework import ops
26from tensorflow.python.saved_model import signature_constants
27from tensorflow.python.saved_model import utils_impl as utils
28from tensorflow.python.util import deprecation
29from tensorflow.python.util.tf_export import tf_export
30
31
@tf_export(
    v1=[
        'saved_model.build_signature_def',
        'saved_model.signature_def_utils.build_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.build_signature_def')
def build_signature_def(inputs=None, outputs=None, method_name=None):
  """Assembles a SignatureDef protocol buffer from its parts.

  Every entry of `inputs` and `outputs` is copied (via `CopyFrom`) into the
  matching map of a freshly created SignatureDef. Arguments left as `None`
  leave the corresponding field unset.

  Args:
    inputs: Optional mapping of string key to TensorInfo used to populate the
      SignatureDef `inputs` map.
    outputs: Optional mapping of string key to TensorInfo used to populate the
      SignatureDef `outputs` map.
    method_name: Optional method name of the SignatureDef as a string.

  Returns:
    A new SignatureDef protocol buffer built from the supplied arguments.
  """
  result = meta_graph_pb2.SignatureDef()
  # Inputs are filled before outputs, mirroring the field order of the proto.
  for source, target in ((inputs, result.inputs), (outputs, result.outputs)):
    if source is None:
      continue
    for name in source:
      target[name].CopyFrom(source[name])
  if method_name is not None:
    result.method_name = method_name
  return result
62
63
@tf_export(
    v1=[
        'saved_model.regression_signature_def',
        'saved_model.signature_def_utils.regression_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.regression_signature_def')
def regression_signature_def(examples, predictions):
  """Creates regression signature from given examples and predictions.

  This function produces signatures intended for use with the TensorFlow Serving
  Regress API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    predictions: A float `Tensor`.

  Returns:
    A regression-flavored signature_def.

  Raises:
    ValueError: If examples is `None`, is not a `Tensor`, or is not of string
      dtype; or if predictions is `None` or is not of float dtype.
  """
  if examples is None:
    raise ValueError('Regression examples cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Regression examples must be a string Tensor.')
  if predictions is None:
    raise ValueError('Regression predictions cannot be None.')

  # TensorFlow Serving's Regress API only accepts serialized examples (strings)
  # in and float values out, so both dtypes are enforced here.
  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Regression examples must be a string Tensor.')
  signature_inputs = {signature_constants.REGRESS_INPUTS: input_tensor_info}

  output_tensor_info = utils.build_tensor_info(predictions)
  if output_tensor_info.dtype != types_pb2.DT_FLOAT:
    raise ValueError('Regression output must be a float Tensor.')
  signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor_info}

  signature_def = build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.REGRESS_METHOD_NAME)

  return signature_def
110
111
@tf_export(
    v1=[
        'saved_model.classification_signature_def',
        'saved_model.signature_def_utils.classification_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.classification_signature_def')
def classification_signature_def(examples, classes, scores):
  """Creates classification signature from given examples, classes and scores.

  This function produces signatures intended for use with the TensorFlow Serving
  Classify API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  At least one of `classes` and `scores` must be provided; only those that are
  not `None` are added to the signature outputs.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    classes: A string `Tensor`, or `None`.  Note that the ClassificationResponse
      message requires that class labels are strings, not integers or anything
      else.
    scores: A float `Tensor`, or `None`.

  Returns:
    A classification-flavored signature_def.

  Raises:
    ValueError: If examples is `None`, is not a `Tensor`, or is not of string
      dtype; if classes and scores are both `None`; if classes is given but is
      not of string dtype; or if scores is given but is not of float dtype.
  """
  if examples is None:
    raise ValueError('Classification examples cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Classification examples must be a string Tensor.')
  if classes is None and scores is None:
    raise ValueError('Classification classes and scores cannot both be None.')

  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Classification examples must be a string Tensor.')
  signature_inputs = {signature_constants.CLASSIFY_INPUTS: input_tensor_info}

  # Each output is optional individually; the None-check above guarantees at
  # least one of them ends up in the signature.
  signature_outputs = {}
  if classes is not None:
    classes_tensor_info = utils.build_tensor_info(classes)
    if classes_tensor_info.dtype != types_pb2.DT_STRING:
      raise ValueError('Classification classes must be a string Tensor.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = (
        classes_tensor_info)
  if scores is not None:
    scores_tensor_info = utils.build_tensor_info(scores)
    if scores_tensor_info.dtype != types_pb2.DT_FLOAT:
      raise ValueError('Classification scores must be a float Tensor.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = (
        scores_tensor_info)

  signature_def = build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.CLASSIFY_METHOD_NAME)

  return signature_def
169
170
@tf_export(
    v1=[
        'saved_model.predict_signature_def',
        'saved_model.signature_def_utils.predict_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.predict_signature_def')
def predict_signature_def(inputs, outputs):
  """Creates prediction signature from given inputs and outputs.

  This function produces signatures intended for use with the TensorFlow Serving
  Predict API (tensorflow_serving/apis/prediction_service.proto). This API
  imposes no constraints on the input and output types.

  Args:
    inputs: dict of string to `Tensor`.
    outputs: dict of string to `Tensor`.

  Returns:
    A prediction-flavored signature_def.

  Raises:
    ValueError: If inputs or outputs is `None` or empty.
  """
  # `not x` is true for both None and an empty mapping, covering both cases.
  if not inputs:
    raise ValueError('Prediction inputs cannot be None or empty.')
  if not outputs:
    raise ValueError('Prediction outputs cannot be None or empty.')

  signature_inputs = {key: utils.build_tensor_info(tensor)
                      for key, tensor in inputs.items()}
  signature_outputs = {key: utils.build_tensor_info(tensor)
                       for key, tensor in outputs.items()}

  signature_def = build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.PREDICT_METHOD_NAME)

  return signature_def
210
211
# LINT.IfChange
def supervised_train_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates a SignatureDef describing a supervised training step.

  Delegates to `_supervised_signature_def` with the method name fixed to
  `signature_constants.SUPERVISED_TRAIN_METHOD_NAME`.

  Args:
    inputs: dict of string to `Tensor`; must be non-empty.
    loss: dict of string to `Tensor` representing computed loss, or `None`.
    predictions: optional dict of string to `Tensor` of output predictions.
    metrics: optional dict of string to `Tensor` representing metric ops.

  Returns:
    A train-flavored SignatureDef.

  Raises:
    ValueError: If inputs is `None` or empty.
  """
  method = signature_constants.SUPERVISED_TRAIN_METHOD_NAME
  return _supervised_signature_def(
      method, inputs, loss=loss, predictions=predictions, metrics=metrics)
218
219
def supervised_eval_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates a SignatureDef describing a supervised evaluation step.

  Delegates to `_supervised_signature_def` with the method name fixed to
  `signature_constants.SUPERVISED_EVAL_METHOD_NAME`.

  Args:
    inputs: dict of string to `Tensor`; must be non-empty.
    loss: dict of string to `Tensor` representing computed loss, or `None`.
    predictions: optional dict of string to `Tensor` of output predictions.
    metrics: optional dict of string to `Tensor` representing metric ops.

  Returns:
    An eval-flavored SignatureDef.

  Raises:
    ValueError: If inputs is `None` or empty.
  """
  method = signature_constants.SUPERVISED_EVAL_METHOD_NAME
  return _supervised_signature_def(
      method, inputs, loss=loss, predictions=predictions, metrics=metrics)
225
226
def _supervised_signature_def(
    method_name, inputs, loss=None, predictions=None,
    metrics=None):
  """Creates a signature for training and eval data.

  This function produces signatures that describe the inputs and outputs
  of a supervised process, such as training or evaluation, that
  results in loss, metrics, and the like. Note that this function only requires
  inputs to be non-empty; loss, predictions and metrics may each be `None`.

  Args:
    method_name: Method name of the SignatureDef as a string.
    inputs: dict of string to `Tensor`.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: dict of string to `Tensor` representing the output predictions.
    metrics: dict of string to `Tensor` representing metric ops.

  Returns:
    A train- or eval-flavored signature_def.

  Raises:
    ValueError: If inputs is `None` or empty.
  """
  # `not inputs` is true for both None and an empty mapping.
  if not inputs:
    raise ValueError('{} inputs cannot be None or empty.'.format(method_name))

  signature_inputs = {key: utils.build_tensor_info(tensor)
                      for key, tensor in inputs.items()}

  # loss, predictions and metrics are merged into one flat outputs map. On a
  # key collision the later set silently wins (metrics over predictions over
  # loss), so keys should be kept distinct across the three dicts if every
  # entry is meant to survive.
  signature_outputs = {}
  for output_set in (loss, predictions, metrics):
    if output_set is not None:
      sig_out = {key: utils.build_tensor_info(tensor)
                 for key, tensor in output_set.items()}
      signature_outputs.update(sig_out)

  signature_def = build_signature_def(
      signature_inputs, signature_outputs, method_name)

  return signature_def
# LINT.ThenChange(//tensorflow/python/keras/saving/utils_v1/signature_def_utils.py)
268
269
@tf_export(
    v1=[
        'saved_model.is_valid_signature',
        'saved_model.signature_def_utils.is_valid_signature'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.is_valid_signature')
def is_valid_signature(signature_def):
  """Determine whether a SignatureDef can be served by TensorFlow Serving.

  The signature is servable if it passes the classification, regression or
  predict check. The checks run in that order and stop at the first success.

  Args:
    signature_def: a SignatureDef proto, or `None`.

  Returns:
    True if the signature is servable; False otherwise, including for `None`.
  """
  if signature_def is None:
    return False
  validators = (_is_valid_classification_signature,
                _is_valid_regression_signature,
                _is_valid_predict_signature)
  return any(check(signature_def) for check in validators)
284
285
def _is_valid_predict_signature(signature_def):
  """Determine whether the argument is a servable 'predict' SignatureDef.

  Requires the predict method name plus at least one input and one output.
  No dtype constraints are checked.
  """
  if signature_def.method_name != signature_constants.PREDICT_METHOD_NAME:
    return False
  has_inputs = bool(signature_def.inputs.keys())
  has_outputs = bool(signature_def.outputs.keys())
  return has_inputs and has_outputs
295
296
def _is_valid_regression_signature(signature_def):
  """Determine whether the argument is a servable 'regress' SignatureDef.

  A valid regression signature uses the regress method name, has exactly one
  input keyed `REGRESS_INPUTS` of string dtype, and exactly one output keyed
  `REGRESS_OUTPUTS` of float dtype.
  """
  if signature_def.method_name != signature_constants.REGRESS_METHOD_NAME:
    return False

  # (tensor map, the single key it must contain, required dtype), checked in
  # order: inputs first, then outputs.
  requirements = (
      (signature_def.inputs, signature_constants.REGRESS_INPUTS,
       types_pb2.DT_STRING),
      (signature_def.outputs, signature_constants.REGRESS_OUTPUTS,
       types_pb2.DT_FLOAT),
  )
  for tensor_map, only_key, dtype in requirements:
    if set(tensor_map.keys()) != {only_key}:
      return False
    # Safe to index: the key set was just confirmed to be exactly {only_key}.
    if tensor_map[only_key].dtype != dtype:
      return False
  return True
317
318
def _is_valid_classification_signature(signature_def):
  """Determine whether the argument is a servable 'classify' SignatureDef.

  A valid classification signature uses the classify method name, has exactly
  one string input keyed `CLASSIFY_INPUTS`, and has a non-empty set of outputs
  drawn only from `CLASSIFY_OUTPUT_CLASSES` (string dtype) and
  `CLASSIFY_OUTPUT_SCORES` (float dtype).
  """
  if signature_def.method_name != signature_constants.CLASSIFY_METHOD_NAME:
    return False

  input_key = signature_constants.CLASSIFY_INPUTS
  if set(signature_def.inputs.keys()) != {input_key}:
    return False
  if signature_def.inputs[input_key].dtype != types_pb2.DT_STRING:
    return False

  # Each permitted output key mapped to the dtype it must carry.
  output_dtypes = {
      signature_constants.CLASSIFY_OUTPUT_CLASSES: types_pb2.DT_STRING,
      signature_constants.CLASSIFY_OUTPUT_SCORES: types_pb2.DT_FLOAT,
  }
  present = set(signature_def.outputs.keys())
  if not present or not present.issubset(output_dtypes):
    return False
  # Only keys known to be present are indexed, so the proto map is not touched.
  return all(signature_def.outputs[name].dtype == output_dtypes[name]
             for name in present)
350
351
def op_signature_def(op, key):
  """Creates a SignatureDef whose single output refers to `op`.

  `op` is not strictly required to be an Operation and may be a Tensor. For
  Tensors, build_signature_def() is the recommended function.

  Args:
    op: An Op (or possibly Tensor).
    key: Key under which the graph element is stored in the SignatureDef
      outputs.

  Returns:
    A SignatureDef with a single output pointing to the op.
  """
  # build_tensor_info_from_op derives the TensorInfo from the element's name.
  element_info = utils.build_tensor_info_from_op(op)
  return build_signature_def(outputs={key: element_info})
368
369
def load_op_from_signature_def(signature_def, key, import_scope=None):
  """Load an Op from a SignatureDef created by op_signature_def().

  The SignatureDef is only read; it is never modified, even when `key` is
  absent from its outputs.

  Args:
    signature_def: a SignatureDef proto
    key: string key to op in the SignatureDef outputs.
    import_scope: Scope used to import the op

  Returns:
    Op (or possibly Tensor) in the graph with the same name as saved in the
      SignatureDef.

  Raises:
    NotFoundError: If `key` is not in the SignatureDef outputs, or if the op
      could not be found in the graph.
  """
  not_found_message = (
      'The {0} could not be found in the graph. Please make sure the '
      'SavedModel was created by the internal _SavedModelBuilder. If you '
      'are using the public API, please make sure the SignatureDef in the '
      'SavedModel does not contain the key "{0}".'.format(key))
  # Indexing a protobuf message map with a missing key inserts an empty entry,
  # which would silently mutate the caller's SignatureDef. Check membership
  # first so this lookup never modifies its argument.
  if key not in signature_def.outputs:
    raise errors.NotFoundError(None, None, not_found_message)
  tensor_info = signature_def.outputs[key]
  try:
    # The init and train ops are not strictly enforced to be operations, so
    # retrieve any graph element (can be either op or tensor).
    return utils.get_element_from_tensor_info(
        tensor_info, import_scope=import_scope)
  except KeyError:
    raise errors.NotFoundError(None, None, not_found_message)
398