• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for SignatureDef utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.framework import types_pb2
22from tensorflow.core.protobuf import meta_graph_pb2
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.ops import array_ops
26from tensorflow.python.platform import test
27from tensorflow.python.saved_model import signature_constants
28from tensorflow.python.saved_model import signature_def_utils_impl
29from tensorflow.python.saved_model import utils
30
31
# Shared TensorInfo fixtures for the signature-validation tests. The same
# objects are deliberately reused across many signatures; shapes are left
# unset because the validator does not check them.
_STRING = meta_graph_pb2.TensorInfo(
    name="foobar", dtype=dtypes.string.as_datatype_enum)
_FLOAT = meta_graph_pb2.TensorInfo(
    name="foobar", dtype=dtypes.float32.as_datatype_enum)
44
45
def _make_signature(inputs, outputs, name=None):
  """Builds a SignatureDef from dicts of tensors.

  Args:
    inputs: dict mapping input keys to tensors.
    outputs: dict mapping output keys to tensors.
    name: forwarded as the third positional argument of
      `build_signature_def`, i.e. the signature's method name; may be None.

  Returns:
    The SignatureDef returned by `signature_def_utils_impl.build_signature_def`.
  """

  def _tensor_infos(tensors_by_key):
    # Convert every tensor into its TensorInfo proto, keeping the same keys.
    return {key: utils.build_tensor_info(t)
            for key, t in tensors_by_key.items()}

  # Arguments are evaluated left to right, so inputs are converted first,
  # exactly as before.
  return signature_def_utils_impl.build_signature_def(
      _tensor_infos(inputs), _tensor_infos(outputs), name)
57
58
class SignatureDefUtilsTest(test.TestCase):
  """Tests for building, inspecting and validating SignatureDefs."""

  def _assertTensorInfo(self, tensor_info, name, dtype, dim_sizes):
    """Asserts the name, dtype and per-dimension sizes of a TensorInfo.

    Args:
      tensor_info: the `TensorInfo` proto under test.
      name: expected tensor name, e.g. "x:0".
      dtype: expected `types_pb2` DataType enum value.
      dim_sizes: expected list of `tensor_shape.dim[i].size` values; pass `[]`
        when the proto should record no dimensions at all.
    """
    self.assertEqual(name, tensor_info.name)
    self.assertEqual(dtype, tensor_info.dtype)
    self.assertEqual(dim_sizes,
                     [dim.size for dim in tensor_info.tensor_shape.dim])

  def testBuildSignatureDef(self):
    x = array_ops.placeholder(dtypes.float32, 1, name="x")
    inputs = {"foo-input": utils.build_tensor_info(x)}

    y = array_ops.placeholder(dtypes.float32, name="y")
    outputs = {"foo-output": utils.build_tensor_info(y)}

    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, "foo-method-name")
    self.assertEqual("foo-method-name", signature_def.method_name)

    # Check inputs in signature def. Lengths are always checked before
    # indexing, because indexing a message-valued proto map inserts the key
    # when it is missing.
    self.assertEqual(1, len(signature_def.inputs))
    self._assertTensorInfo(signature_def.inputs["foo-input"], "x:0",
                           types_pb2.DT_FLOAT, [1])

    # Check outputs in signature def. `y` was created without a shape, so no
    # dimensions are recorded.
    self.assertEqual(1, len(signature_def.outputs))
    self._assertTensorInfo(signature_def.outputs["foo-output"], "y:0",
                           types_pb2.DT_FLOAT, [])

  def testRegressionSignatureDef(self):
    input1 = constant_op.constant("a", name="input-1")
    output1 = constant_op.constant(2.2, name="output-1")
    signature_def = signature_def_utils_impl.regression_signature_def(
        input1, output1)

    self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
                     signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(signature_def.inputs))
    self._assertTensorInfo(
        signature_def.inputs[signature_constants.REGRESS_INPUTS],
        "input-1:0", types_pb2.DT_STRING, [])

    # Check outputs in signature def.
    self.assertEqual(1, len(signature_def.outputs))
    self._assertTensorInfo(
        signature_def.outputs[signature_constants.REGRESS_OUTPUTS],
        "output-1:0", types_pb2.DT_FLOAT, [])

  def testClassificationSignatureDef(self):
    input1 = constant_op.constant("a", name="input-1")
    output1 = constant_op.constant("b", name="output-1")
    output2 = constant_op.constant(3.3, name="output-2")
    signature_def = signature_def_utils_impl.classification_signature_def(
        input1, output1, output2)

    self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
                     signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(signature_def.inputs))
    self._assertTensorInfo(
        signature_def.inputs[signature_constants.CLASSIFY_INPUTS],
        "input-1:0", types_pb2.DT_STRING, [])

    # Check outputs in signature def.
    self.assertEqual(2, len(signature_def.outputs))
    self._assertTensorInfo(
        signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES],
        "output-1:0", types_pb2.DT_STRING, [])
    self._assertTensorInfo(
        signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES],
        "output-2:0", types_pb2.DT_FLOAT, [])

  def testPredictionSignatureDef(self):
    input1 = constant_op.constant("a", name="input-1")
    input2 = constant_op.constant("b", name="input-2")
    output1 = constant_op.constant("c", name="output-1")
    output2 = constant_op.constant("d", name="output-2")
    signature_def = signature_def_utils_impl.predict_signature_def(
        {"input-1": input1, "input-2": input2},
        {"output-1": output1, "output-2": output2})

    self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
                     signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(2, len(signature_def.inputs))
    self._assertTensorInfo(signature_def.inputs["input-1"], "input-1:0",
                           types_pb2.DT_STRING, [])
    self._assertTensorInfo(signature_def.inputs["input-2"], "input-2:0",
                           types_pb2.DT_STRING, [])

    # Check outputs in signature def.
    self.assertEqual(2, len(signature_def.outputs))
    self._assertTensorInfo(signature_def.outputs["output-1"], "output-1:0",
                           types_pb2.DT_STRING, [])
    self._assertTensorInfo(signature_def.outputs["output-2"], "output-2:0",
                           types_pb2.DT_STRING, [])

  def testGetShapeAndTypes(self):
    inputs = {
        "input-1": constant_op.constant(["a", "b"]),
        "input-2": array_ops.placeholder(dtypes.float32, [10, 11]),
    }
    outputs = {
        "output-1": array_ops.placeholder(dtypes.float32, [10, 32]),
        "output-2": constant_op.constant([["b"]]),
    }
    signature_def = _make_signature(inputs, outputs)
    self.assertEqual(
        {"input-1": [2], "input-2": [10, 11]},
        signature_def_utils_impl.get_signature_def_input_shapes(signature_def))
    self.assertEqual(
        {"output-1": [10, 32], "output-2": [1, 1]},
        signature_def_utils_impl.get_signature_def_output_shapes(signature_def))
    self.assertEqual(
        {"input-1": dtypes.string, "input-2": dtypes.float32},
        signature_def_utils_impl.get_signature_def_input_types(signature_def))
    self.assertEqual(
        {"output-1": dtypes.float32, "output-2": dtypes.string},
        signature_def_utils_impl.get_signature_def_output_types(signature_def))

  def testGetNonFullySpecifiedShapes(self):
    outputs = {
        "output-1": array_ops.placeholder(dtypes.float32, [None, 10, None]),
        "output-2": array_ops.sparse_placeholder(dtypes.float32),
    }
    signature_def = _make_signature({}, outputs)
    shapes = signature_def_utils_impl.get_signature_def_output_shapes(
        signature_def)
    self.assertEqual(2, len(shapes))
    # Must compare shapes with as_list() since 2 equivalent non-fully defined
    # shapes are not equal to each other.
    self.assertEqual([None, 10, None], shapes["output-1"].as_list())
    # Must compare `dims` since it's an unknown shape.
    self.assertIsNone(shapes["output-2"].dims)

  def _assertValidSignature(self, inputs, outputs, method_name):
    """Asserts that the SignatureDef built from the arguments is valid."""
    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, method_name)
    self.assertTrue(
        signature_def_utils_impl.is_valid_signature(signature_def))

  def _assertInvalidSignature(self, inputs, outputs, method_name):
    """Asserts that the SignatureDef built from the arguments is invalid."""
    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, method_name)
    self.assertFalse(
        signature_def_utils_impl.is_valid_signature(signature_def))

  def testValidSignaturesAreAccepted(self):
    self._assertValidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertValidSignature(
        {"inputs": _STRING},
        {"classes": _STRING},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertValidSignature(
        {"inputs": _STRING},
        {"scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertValidSignature(
        {"inputs": _STRING},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertValidSignature(
        {"foo": _STRING, "bar": _FLOAT},
        {"baz": _STRING, "qux": _FLOAT},
        signature_constants.PREDICT_METHOD_NAME)

  def testInvalidMethodNameSignatureIsRejected(self):
    # Otherwise-valid classification signature with an unrecognized method
    # name.
    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _FLOAT},
        "WRONG method name")

  def testInvalidClassificationSignaturesAreRejected(self):
    # CLASSIFY: wrong types
    self._assertInvalidSignature(
        {"inputs": _FLOAT},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _FLOAT, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _STRING},
        signature_constants.CLASSIFY_METHOD_NAME)

    # CLASSIFY: wrong keys
    self._assertInvalidSignature(
        {},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs_WRONG": _STRING},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes_WRONG": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _FLOAT, "extra_WRONG": _STRING},
        signature_constants.CLASSIFY_METHOD_NAME)

  def testInvalidRegressionSignaturesAreRejected(self):
    # REGRESS: wrong types
    self._assertInvalidSignature(
        {"inputs": _FLOAT},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"outputs": _STRING},
        signature_constants.REGRESS_METHOD_NAME)

    # REGRESS: wrong keys
    self._assertInvalidSignature(
        {},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs_WRONG": _STRING},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"outputs_WRONG": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"outputs": _FLOAT, "extra_WRONG": _STRING},
        signature_constants.REGRESS_METHOD_NAME)

  def testInvalidPredictSignaturesAreRejected(self):
    # PREDICT: empty inputs or empty outputs are rejected.
    self._assertInvalidSignature(
        {},
        {"baz": _STRING, "qux": _FLOAT},
        signature_constants.PREDICT_METHOD_NAME)

    self._assertInvalidSignature(
        {"foo": _STRING, "bar": _FLOAT},
        {},
        signature_constants.PREDICT_METHOD_NAME)
if __name__ == "__main__":
  # Script entry point: delegate to TensorFlow's test runner.
  test.main()
361