# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SignatureDef utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils_impl
from tensorflow.python.saved_model import utils


# These two TensorInfo protos are shared by all of the signature-validation
# tests below. No shape is set on them because the validator does not check
# shapes.
_STRING = meta_graph_pb2.TensorInfo(
    name="foobar", dtype=dtypes.string.as_datatype_enum)

_FLOAT = meta_graph_pb2.TensorInfo(
    name="foobar", dtype=dtypes.float32.as_datatype_enum)


def _make_signature(inputs, outputs, name=None):
  """Builds a SignatureDef from dicts that map keys to tensors.

  Args:
    inputs: dict mapping an input key to a tensor. Each tensor is converted
      with `utils.build_tensor_info`.
    outputs: dict mapping an output key to a tensor, converted the same way.
    name: passed through as the third positional argument of
      `signature_def_utils_impl.build_signature_def`.

  Returns:
    Whatever `signature_def_utils_impl.build_signature_def` returns for the
    converted inputs and outputs.
  """
  input_infos = {}
  for key, tensor in inputs.items():
    input_infos[key] = utils.build_tensor_info(tensor)

  output_infos = {}
  for key, tensor in outputs.items():
    output_infos[key] = utils.build_tensor_info(tensor)

  return signature_def_utils_impl.build_signature_def(
      input_infos, output_infos, name)


class SignatureDefUtilsTest(test.TestCase):
  """Tests for the builders, accessors and validator of SignatureDefs."""

  def _assertTensorInfo(self, tensor_info, expected_name, expected_dtype,
                        expected_num_dims):
    """Asserts the name, dtype and number of shape dims of a TensorInfo.

    Args:
      tensor_info: the TensorInfo under test.
      expected_name: expected value of `tensor_info.name`.
      expected_dtype: expected value of `tensor_info.dtype`.
      expected_num_dims: expected `len(tensor_info.tensor_shape.dim)`.
    """
    self.assertEqual(expected_name, tensor_info.name)
    self.assertEqual(expected_dtype, tensor_info.dtype)
    self.assertEqual(expected_num_dims, len(tensor_info.tensor_shape.dim))

  def testBuildSignatureDef(self):
    # One float input with shape [1] and one float output with no shape given.
    x = array_ops.placeholder(dtypes.float32, 1, name="x")
    inputs = {"foo-input": utils.build_tensor_info(x)}

    y = array_ops.placeholder(dtypes.float32, name="y")
    outputs = {"foo-output": utils.build_tensor_info(y)}

    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, "foo-method-name")
    self.assertEqual("foo-method-name", signature_def.method_name)

    # Inputs: the entry count is asserted before the entry is looked up.
    self.assertEqual(1, len(signature_def.inputs))
    actual_x = signature_def.inputs["foo-input"]
    self._assertTensorInfo(actual_x, "x:0", types_pb2.DT_FLOAT, 1)
    self.assertEqual(1, actual_x.tensor_shape.dim[0].size)

    # Outputs.
    self.assertEqual(1, len(signature_def.outputs))
    actual_y = signature_def.outputs["foo-output"]
    self._assertTensorInfo(actual_y, "y:0", types_pb2.DT_FLOAT, 0)

  def testRegressionSignatureDef(self):
    examples = constant_op.constant("a", name="input-1")
    predictions = constant_op.constant(2.2, name="output-1")
    signature_def = signature_def_utils_impl.regression_signature_def(
        examples, predictions)

    self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
                     signature_def.method_name)

    # Inputs.
    self.assertEqual(1, len(signature_def.inputs))
    actual_input = signature_def.inputs[signature_constants.REGRESS_INPUTS]
    self._assertTensorInfo(actual_input, "input-1:0", types_pb2.DT_STRING, 0)

    # Outputs.
    self.assertEqual(1, len(signature_def.outputs))
    actual_output = signature_def.outputs[signature_constants.REGRESS_OUTPUTS]
    self._assertTensorInfo(actual_output, "output-1:0", types_pb2.DT_FLOAT, 0)

  def testClassificationSignatureDef(self):
    examples = constant_op.constant("a", name="input-1")
    classes = constant_op.constant("b", name="output-1")
    scores = constant_op.constant(3.3, name="output-2")
    signature_def = signature_def_utils_impl.classification_signature_def(
        examples, classes, scores)

    self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
                     signature_def.method_name)

    # Inputs.
    self.assertEqual(1, len(signature_def.inputs))
    actual_input = signature_def.inputs[signature_constants.CLASSIFY_INPUTS]
    self._assertTensorInfo(actual_input, "input-1:0", types_pb2.DT_STRING, 0)

    # Outputs: classes first, then scores.
    self.assertEqual(2, len(signature_def.outputs))
    actual_classes = (
        signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES])
    self._assertTensorInfo(actual_classes, "output-1:0", types_pb2.DT_STRING, 0)
    actual_scores = (
        signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES])
    self._assertTensorInfo(actual_scores, "output-2:0", types_pb2.DT_FLOAT, 0)

  def testPredictionSignatureDef(self):
    input1 = constant_op.constant("a", name="input-1")
    input2 = constant_op.constant("b", name="input-2")
    output1 = constant_op.constant("c", name="output-1")
    output2 = constant_op.constant("d", name="output-2")
    signature_inputs = {"input-1": input1, "input-2": input2}
    signature_outputs = {"output-1": output1, "output-2": output2}
    signature_def = signature_def_utils_impl.predict_signature_def(
        signature_inputs, signature_outputs)

    self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
                     signature_def.method_name)

    # Inputs: every entry is a string constant named after its key.
    self.assertEqual(2, len(signature_def.inputs))
    for key in ("input-1", "input-2"):
      self._assertTensorInfo(
          signature_def.inputs[key], key + ":0", types_pb2.DT_STRING, 0)

    # Outputs: same pattern as the inputs.
    self.assertEqual(2, len(signature_def.outputs))
    for key in ("output-1", "output-2"):
      self._assertTensorInfo(
          signature_def.outputs[key], key + ":0", types_pb2.DT_STRING, 0)

  def testGetShapeAndTypes(self):
    inputs = {
        "input-1": constant_op.constant(["a", "b"]),
        "input-2": array_ops.placeholder(dtypes.float32, [10, 11]),
    }
    outputs = {
        "output-1": array_ops.placeholder(dtypes.float32, [10, 32]),
        "output-2": constant_op.constant([["b"]]),
    }
    signature_def = _make_signature(inputs, outputs)

    input_shapes = (
        signature_def_utils_impl.get_signature_def_input_shapes(signature_def))
    self.assertEqual(input_shapes, {"input-1": [2], "input-2": [10, 11]})

    output_shapes = (
        signature_def_utils_impl.get_signature_def_output_shapes(signature_def))
    self.assertEqual(output_shapes, {"output-1": [10, 32], "output-2": [1, 1]})

    input_types = (
        signature_def_utils_impl.get_signature_def_input_types(signature_def))
    self.assertEqual(
        input_types, {"input-1": dtypes.string, "input-2": dtypes.float32})

    output_types = (
        signature_def_utils_impl.get_signature_def_output_types(signature_def))
    self.assertEqual(
        output_types, {"output-1": dtypes.float32, "output-2": dtypes.string})

  def testGetNonFullySpecifiedShapes(self):
    outputs = {
        "output-1": array_ops.placeholder(dtypes.float32, [None, 10, None]),
        "output-2": array_ops.sparse_placeholder(dtypes.float32),
    }
    signature_def = _make_signature({}, outputs)
    shapes = signature_def_utils_impl.get_signature_def_output_shapes(
        signature_def)
    self.assertEqual(len(shapes), 2)

    # Two equivalent shapes that are not fully defined do not compare equal,
    # so the comparison goes through as_list().
    partially_known = shapes["output-1"]
    self.assertEqual(partially_known.as_list(), [None, 10, None])

    # This shape is unknown, so `dims` is what gets compared.
    unknown = shapes["output-2"]
    self.assertEqual(unknown.dims, None)

  def _assertValidSignature(self, inputs, outputs, method_name):
    """Builds a SignatureDef and asserts that the validator accepts it.

    Args:
      inputs: dict mapping input keys to TensorInfo protos.
      outputs: dict mapping output keys to TensorInfo protos.
      method_name: method name to put on the SignatureDef.
    """
    candidate = signature_def_utils_impl.build_signature_def(
        inputs, outputs, method_name)
    accepted = signature_def_utils_impl.is_valid_signature(candidate)
    self.assertTrue(accepted)

  def _assertInvalidSignature(self, inputs, outputs, method_name):
    """Builds a SignatureDef and asserts that the validator rejects it.

    Args:
      inputs: dict mapping input keys to TensorInfo protos.
      outputs: dict mapping output keys to TensorInfo protos.
      method_name: method name to put on the SignatureDef.
    """
    candidate = signature_def_utils_impl.build_signature_def(
        inputs, outputs, method_name)
    accepted = signature_def_utils_impl.is_valid_signature(candidate)
    self.assertFalse(accepted)

  def testValidSignaturesAreAccepted(self):
    classify = signature_constants.CLASSIFY_METHOD_NAME
    regress = signature_constants.REGRESS_METHOD_NAME
    predict = signature_constants.PREDICT_METHOD_NAME

    # Each row is (inputs, outputs, method_name); rows are checked in order.
    valid_cases = (
        ({"inputs": _STRING}, {"classes": _STRING, "scores": _FLOAT},
         classify),
        ({"inputs": _STRING}, {"classes": _STRING}, classify),
        ({"inputs": _STRING}, {"scores": _FLOAT}, classify),
        ({"inputs": _STRING}, {"outputs": _FLOAT}, regress),
        ({"foo": _STRING, "bar": _FLOAT}, {"baz": _STRING, "qux": _FLOAT},
         predict),
    )
    for inputs, outputs, method_name in valid_cases:
      self._assertValidSignature(inputs, outputs, method_name)

  def testInvalidMethodNameSignatureIsRejected(self):
    # The keys and dtypes match the valid CLASSIFY case; only the method name
    # is wrong.
    inputs = {"inputs": _STRING}
    outputs = {"classes": _STRING, "scores": _FLOAT}
    self._assertInvalidSignature(inputs, outputs, "WRONG method name")

  def testInvalidClassificationSignaturesAreRejected(self):
    # Each row is (inputs, outputs); rows are checked in order.
    invalid_cases = (
        # CLASSIFY: wrong types
        ({"inputs": _FLOAT}, {"classes": _STRING, "scores": _FLOAT}),
        ({"inputs": _STRING}, {"classes": _FLOAT, "scores": _FLOAT}),
        ({"inputs": _STRING}, {"classes": _STRING, "scores": _STRING}),
        # CLASSIFY: wrong keys
        ({}, {"classes": _STRING, "scores": _FLOAT}),
        ({"inputs_WRONG": _STRING}, {"classes": _STRING, "scores": _FLOAT}),
        ({"inputs": _STRING}, {"classes_WRONG": _STRING, "scores": _FLOAT}),
        ({"inputs": _STRING}, {}),
        ({"inputs": _STRING},
         {"classes": _STRING, "scores": _FLOAT, "extra_WRONG": _STRING}),
    )
    for inputs, outputs in invalid_cases:
      self._assertInvalidSignature(
          inputs, outputs, signature_constants.CLASSIFY_METHOD_NAME)

  def testInvalidRegressionSignaturesAreRejected(self):
    # Each row is (inputs, outputs); rows are checked in order.
    invalid_cases = (
        # REGRESS: wrong types
        ({"inputs": _FLOAT}, {"outputs": _FLOAT}),
        ({"inputs": _STRING}, {"outputs": _STRING}),
        # REGRESS: wrong keys
        ({}, {"outputs": _FLOAT}),
        ({"inputs_WRONG": _STRING}, {"outputs": _FLOAT}),
        ({"inputs": _STRING}, {"outputs_WRONG": _FLOAT}),
        ({"inputs": _STRING}, {}),
        ({"inputs": _STRING}, {"outputs": _FLOAT, "extra_WRONG": _STRING}),
    )
    for inputs, outputs in invalid_cases:
      self._assertInvalidSignature(
          inputs, outputs, signature_constants.REGRESS_METHOD_NAME)

  def testInvalidPredictSignaturesAreRejected(self):
    # PREDICT: wrong keys. Each row is (inputs, outputs), checked in order.
    invalid_cases = (
        ({}, {"baz": _STRING, "qux": _FLOAT}),
        ({"foo": _STRING, "bar": _FLOAT}, {}),
    )
    for inputs, outputs in invalid_cases:
      self._assertInvalidSignature(
          inputs, outputs, signature_constants.PREDICT_METHOD_NAME)


if __name__ == "__main__":
  test.main()