• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SignatureDef utility functions implementation."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21
22from tensorflow.core.framework import types_pb2
23from tensorflow.core.protobuf import meta_graph_pb2
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.saved_model import signature_constants
28from tensorflow.python.saved_model import utils
29from tensorflow.python.util.tf_export import tf_export
30
31
@tf_export('saved_model.signature_def_utils.build_signature_def')
def build_signature_def(inputs=None, outputs=None, method_name=None):
  """Assembles a SignatureDef protocol buffer from the supplied pieces.

  Every argument is optional; an argument left as `None` is simply skipped and
  the corresponding field of the new proto is not written.

  Args:
    inputs: Mapping from input name (string) to a TensorInfo proto. Each entry
      is copied into `SignatureDef.inputs` under the same name.
    outputs: Mapping from output name (string) to a TensorInfo proto. Each
      entry is copied into `SignatureDef.outputs` under the same name.
    method_name: String assigned to `SignatureDef.method_name`.

  Returns:
    A newly constructed SignatureDef protocol buffer.
  """
  result = meta_graph_pb2.SignatureDef()
  # Inputs are copied before outputs, mirroring the argument order.
  for source, destination in ((inputs, result.inputs),
                              (outputs, result.outputs)):
    if source is None:
      continue
    for name in source:
      destination[name].CopyFrom(source[name])
  if method_name is not None:
    result.method_name = method_name
  return result
56
57
@tf_export('saved_model.signature_def_utils.regression_signature_def')
def regression_signature_def(examples, predictions):
  """Creates regression signature from given examples and predictions.

  This function produces signatures intended for use with the TensorFlow Serving
  Regress API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    predictions: A `Tensor` of dtype `DT_FLOAT`; other floating-point dtypes
      are rejected.

  Returns:
    A regression-flavored signature_def with method name
    `signature_constants.REGRESS_METHOD_NAME`, holding `examples` as the single
    input under `REGRESS_INPUTS` and `predictions` as the single output under
    `REGRESS_OUTPUTS`.

  Raises:
    ValueError: If `examples` is `None`, is not a `Tensor`, or is not of string
      dtype; or if `predictions` is `None` or is not of dtype `DT_FLOAT`.
  """
  if examples is None:
    raise ValueError('Regression examples cannot be None.')
  if not isinstance(examples, ops.Tensor):
    # Keep the original message as a prefix; append what was actually passed.
    raise ValueError('Regression examples must be a string Tensor. '
                     'Found an object of type %s.' % type(examples).__name__)
  if predictions is None:
    raise ValueError('Regression predictions cannot be None.')

  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Regression examples must be a string Tensor. '
                     'Found dtype %s.' %
                     dtypes.as_dtype(input_tensor_info.dtype).name)
  signature_inputs = {signature_constants.REGRESS_INPUTS: input_tensor_info}

  output_tensor_info = utils.build_tensor_info(predictions)
  if output_tensor_info.dtype != types_pb2.DT_FLOAT:
    raise ValueError('Regression output must be a float Tensor. '
                     'Found dtype %s.' %
                     dtypes.as_dtype(output_tensor_info.dtype).name)
  signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor_info}

  return build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.REGRESS_METHOD_NAME)
98
99
@tf_export('saved_model.signature_def_utils.classification_signature_def')
def classification_signature_def(examples, classes, scores):
  """Creates classification signature from given examples, classes and scores.

  This function produces signatures intended for use with the TensorFlow Serving
  Classify API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    classes: A string `Tensor`, or `None`.  Note that the ClassificationResponse
      message requires that class labels are strings, not integers or anything
      else.
    scores: A `Tensor` of dtype `DT_FLOAT`, or `None`. At least one of
      `classes` and `scores` must be provided.

  Returns:
    A classification-flavored signature_def with method name
    `signature_constants.CLASSIFY_METHOD_NAME`, holding `examples` under
    `CLASSIFY_INPUTS` and whichever of `classes` / `scores` were given under
    `CLASSIFY_OUTPUT_CLASSES` / `CLASSIFY_OUTPUT_SCORES`.

  Raises:
    ValueError: If `examples` is `None`, is not a `Tensor`, or is not of string
      dtype; if `classes` and `scores` are both `None`; if `classes` is given
      but is not of string dtype; or if `scores` is given but is not of dtype
      `DT_FLOAT`.
  """
  if examples is None:
    raise ValueError('Classification examples cannot be None.')
  if not isinstance(examples, ops.Tensor):
    # Keep the original message as a prefix; append what was actually passed.
    raise ValueError('Classification examples must be a string Tensor. '
                     'Found an object of type %s.' % type(examples).__name__)
  if classes is None and scores is None:
    raise ValueError('Classification classes and scores cannot both be None.')

  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Classification examples must be a string Tensor. '
                     'Found dtype %s.' %
                     dtypes.as_dtype(input_tensor_info.dtype).name)
  signature_inputs = {signature_constants.CLASSIFY_INPUTS: input_tensor_info}

  signature_outputs = {}
  if classes is not None:
    classes_tensor_info = utils.build_tensor_info(classes)
    if classes_tensor_info.dtype != types_pb2.DT_STRING:
      raise ValueError('Classification classes must be a string Tensor. '
                       'Found dtype %s.' %
                       dtypes.as_dtype(classes_tensor_info.dtype).name)
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = (
        classes_tensor_info)
  if scores is not None:
    scores_tensor_info = utils.build_tensor_info(scores)
    if scores_tensor_info.dtype != types_pb2.DT_FLOAT:
      raise ValueError('Classification scores must be a float Tensor. '
                       'Found dtype %s.' %
                       dtypes.as_dtype(scores_tensor_info.dtype).name)
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = (
        scores_tensor_info)

  return build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.CLASSIFY_METHOD_NAME)
151
152
@tf_export('saved_model.signature_def_utils.predict_signature_def')
def predict_signature_def(inputs, outputs):
  """Creates prediction signature from given inputs and outputs.

  This function produces signatures intended for use with the TensorFlow Serving
  Predict API (tensorflow_serving/apis/prediction_service.proto). This API
  imposes no constraints on the input and output types.

  Args:
    inputs: dict of string to `Tensor`.
    outputs: dict of string to `Tensor`.

  Returns:
    A prediction-flavored signature_def with method name
    `signature_constants.PREDICT_METHOD_NAME`, holding one TensorInfo per entry
    of `inputs` and `outputs`, under the same keys.

  Raises:
    ValueError: If `inputs` or `outputs` is `None` or empty.
  """
  # `None` and an empty mapping are both falsy, so one test covers both cases.
  if not inputs:
    raise ValueError('Prediction inputs cannot be None or empty.')
  if not outputs:
    raise ValueError('Prediction outputs cannot be None or empty.')

  signature_inputs = {key: utils.build_tensor_info(tensor)
                      for key, tensor in inputs.items()}
  signature_outputs = {key: utils.build_tensor_info(tensor)
                       for key, tensor in outputs.items()}

  return build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.PREDICT_METHOD_NAME)
186
187
@tf_export('saved_model.signature_def_utils.is_valid_signature')
def is_valid_signature(signature_def):
  """Determine whether a SignatureDef can be served by TensorFlow Serving.

  Args:
    signature_def: A SignatureDef proto, or `None`.

  Returns:
    True if the signature passes the classify, regress or predict validity
    check (tried in that order); False otherwise, including for `None`.
  """
  if signature_def is None:
    return False
  checkers = (_is_valid_classification_signature,
              _is_valid_regression_signature,
              _is_valid_predict_signature)
  return any(check(signature_def) for check in checkers)
196
197
def _is_valid_predict_signature(signature_def):
  """Determine whether the argument is a servable 'predict' SignatureDef.

  The signature must use `PREDICT_METHOD_NAME` and declare at least one input
  and at least one output; their names and dtypes are not constrained.
  """
  return (signature_def.method_name == signature_constants.PREDICT_METHOD_NAME
          and bool(signature_def.inputs.keys())
          and bool(signature_def.outputs.keys()))
207
208
def _is_valid_regression_signature(signature_def):
  """Determine whether the argument is a servable 'regress' SignatureDef.

  The signature must use `REGRESS_METHOD_NAME`, have exactly one input keyed
  `REGRESS_INPUTS` with dtype DT_STRING, and exactly one output keyed
  `REGRESS_OUTPUTS` with dtype DT_FLOAT.
  """
  if signature_def.method_name != signature_constants.REGRESS_METHOD_NAME:
    return False

  # (tensor map, the only key it may contain, dtype that key must have);
  # inputs are checked before outputs.
  requirements = (
      (signature_def.inputs, signature_constants.REGRESS_INPUTS,
       types_pb2.DT_STRING),
      (signature_def.outputs, signature_constants.REGRESS_OUTPUTS,
       types_pb2.DT_FLOAT),
  )
  for tensor_map, only_key, required_dtype in requirements:
    if set(tensor_map.keys()) != {only_key}:
      return False
    # Safe to index: the key set was just confirmed to be exactly {only_key}.
    if tensor_map[only_key].dtype != required_dtype:
      return False
  return True
229
230
def _is_valid_classification_signature(signature_def):
  """Determine whether the argument is a servable 'classify' SignatureDef.

  The signature must use `CLASSIFY_METHOD_NAME`, have exactly one input keyed
  `CLASSIFY_INPUTS` with dtype DT_STRING, and a non-empty set of outputs drawn
  only from `CLASSIFY_OUTPUT_CLASSES` (dtype DT_STRING) and
  `CLASSIFY_OUTPUT_SCORES` (dtype DT_FLOAT).
  """
  if signature_def.method_name != signature_constants.CLASSIFY_METHOD_NAME:
    return False

  inputs = signature_def.inputs
  if set(inputs.keys()) != {signature_constants.CLASSIFY_INPUTS}:
    return False
  if inputs[signature_constants.CLASSIFY_INPUTS].dtype != types_pb2.DT_STRING:
    return False

  outputs = signature_def.outputs
  output_names = set(outputs.keys())
  if not output_names:
    return False
  # Each permitted output name mapped to the dtype it must carry.
  required_dtypes = {
      signature_constants.CLASSIFY_OUTPUT_CLASSES: types_pb2.DT_STRING,
      signature_constants.CLASSIFY_OUTPUT_SCORES: types_pb2.DT_FLOAT,
  }
  if not output_names.issubset(required_dtypes):
    return False
  # Only names actually present are indexed, so no map entries are created.
  return all(outputs[name].dtype == required_dtypes[name]
             for name in output_names)
262
263
264def _get_shapes_from_tensor_info_dict(tensor_info_dict):
265  """Returns a map of keys to TensorShape objects.
266
267  Args:
268    tensor_info_dict: map with TensorInfo proto as values.
269
270  Returns:
271    Map with corresponding TensorShape objects as values.
272  """
273  return {
274      key: tensor_shape.TensorShape(tensor_info.tensor_shape)
275      for key, tensor_info in tensor_info_dict.items()
276  }
277
278
279def _get_types_from_tensor_info_dict(tensor_info_dict):
280  """Returns a map of keys to DType objects.
281
282  Args:
283    tensor_info_dict: map with TensorInfo proto as values.
284
285  Returns:
286    Map with corresponding DType objects as values.
287  """
288  return {
289      key: dtypes.DType(tensor_info.dtype)
290      for key, tensor_info in tensor_info_dict.items()
291  }
292
293
def get_signature_def_input_shapes(signature):
  """Returns map of input names to their shapes.

  Args:
    signature: SignatureDef proto.

  Returns:
    Map from input name (string) to TensorShape objects.
  """
  input_infos = signature.inputs
  return _get_shapes_from_tensor_info_dict(input_infos)
304
305
def get_signature_def_input_types(signature):
  """Returns map of input names to their types.

  Args:
    signature: SignatureDef proto.

  Returns:
    Map from input name (string) to DType objects, built from
    `signature.inputs`.
  """
  return _get_types_from_tensor_info_dict(signature.inputs)
316
317
def get_signature_def_output_shapes(signature):
  """Returns map of output names to their shapes.

  Args:
    signature: SignatureDef proto.

  Returns:
    Map from output name (string) to TensorShape objects.
  """
  output_infos = signature.outputs
  return _get_shapes_from_tensor_info_dict(output_infos)
328
329
def get_signature_def_output_types(signature):
  """Returns map of output names to their types.

  Args:
    signature: SignatureDef proto.

  Returns:
    Map from output name (string) to DType objects.
  """
  output_infos = signature.outputs
  return _get_types_from_tensor_info_dict(output_infos)
340