1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for SignatureDef utils.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.core.framework import types_pb2 22from tensorflow.core.protobuf import meta_graph_pb2 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import test_util 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import math_ops 28from tensorflow.python.platform import test 29from tensorflow.python.saved_model import signature_constants 30from tensorflow.python.saved_model import signature_def_utils_impl 31from tensorflow.python.saved_model import utils 32 33 34# We'll reuse the same tensor_infos in multiple contexts just for the tests. 35# The validator doesn't check shapes so we just omit them. 
_STRING = meta_graph_pb2.TensorInfo(
    name="foobar",
    dtype=dtypes.string.as_datatype_enum
)


_FLOAT = meta_graph_pb2.TensorInfo(
    name="foobar",
    dtype=dtypes.float32.as_datatype_enum
)


def _make_signature(inputs, outputs, name=None):
  """Builds a SignatureDef from dicts of tensors.

  Each tensor is converted with `utils.build_tensor_info` before being handed
  to `signature_def_utils_impl.build_signature_def`.

  Args:
    inputs: dict mapping input keys to tensors.
    outputs: dict mapping output keys to tensors.
    name: forwarded as the third positional argument of `build_signature_def`,
      i.e. the SignatureDef's method name.

  Returns:
    The SignatureDef produced by `build_signature_def`.
  """
  input_info = {
      input_name: utils.build_tensor_info(tensor)
      for input_name, tensor in inputs.items()
  }
  output_info = {
      output_name: utils.build_tensor_info(tensor)
      for output_name, tensor in outputs.items()
  }
  return signature_def_utils_impl.build_signature_def(input_info, output_info,
                                                      name)


class SignatureDefUtilsTest(test.TestCase):
  """Tests for the builders and validators in `signature_def_utils_impl`."""

  @test_util.run_deprecated_v1
  def testBuildSignatureDef(self):
    x = array_ops.placeholder(dtypes.float32, 1, name="x")
    inputs = {"foo-input": utils.build_tensor_info(x)}

    y = array_ops.placeholder(dtypes.float32, name="y")
    outputs = {"foo-output": utils.build_tensor_info(y)}

    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, "foo-method-name")
    self.assertEqual("foo-method-name", signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(signature_def.inputs))
    x_tensor_info_actual = signature_def.inputs["foo-input"]
    self.assertEqual("x:0", x_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_FLOAT, x_tensor_info_actual.dtype)
    # `x` was created with shape argument 1, so expect one dim of size 1.
    self.assertEqual(1, len(x_tensor_info_actual.tensor_shape.dim))
    self.assertEqual(1, x_tensor_info_actual.tensor_shape.dim[0].size)

    # Check outputs in signature def.
    self.assertEqual(1, len(signature_def.outputs))
    y_tensor_info_actual = signature_def.outputs["foo-output"]
    self.assertEqual("y:0", y_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype)
    self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim))

  @test_util.run_deprecated_v1
  def testRegressionSignatureDef(self):
    input1 = constant_op.constant("a", name="input-1")
    output1 = constant_op.constant(2.2, name="output-1")
    signature_def = signature_def_utils_impl.regression_signature_def(
        input1, output1)

    self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
                     signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(signature_def.inputs))
    x_tensor_info_actual = (
        signature_def.inputs[signature_constants.REGRESS_INPUTS])
    self.assertEqual("input-1:0", x_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, x_tensor_info_actual.dtype)
    self.assertEqual(0, len(x_tensor_info_actual.tensor_shape.dim))

    # Check outputs in signature def.
    self.assertEqual(1, len(signature_def.outputs))
    y_tensor_info_actual = (
        signature_def.outputs[signature_constants.REGRESS_OUTPUTS])
    self.assertEqual("output-1:0", y_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype)
    self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim))

  @test_util.run_deprecated_v1
  def testClassificationSignatureDef(self):
    input1 = constant_op.constant("a", name="input-1")
    output1 = constant_op.constant("b", name="output-1")
    output2 = constant_op.constant(3.3, name="output-2")
    signature_def = signature_def_utils_impl.classification_signature_def(
        input1, output1, output2)

    self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
                     signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(signature_def.inputs))
    x_tensor_info_actual = (
        signature_def.inputs[signature_constants.CLASSIFY_INPUTS])
    self.assertEqual("input-1:0", x_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, x_tensor_info_actual.dtype)
    self.assertEqual(0, len(x_tensor_info_actual.tensor_shape.dim))

    # Check outputs in signature def.
    self.assertEqual(2, len(signature_def.outputs))
    classes_tensor_info_actual = (
        signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES])
    self.assertEqual("output-1:0", classes_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, classes_tensor_info_actual.dtype)
    self.assertEqual(0, len(classes_tensor_info_actual.tensor_shape.dim))
    scores_tensor_info_actual = (
        signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES])
    self.assertEqual("output-2:0", scores_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_FLOAT, scores_tensor_info_actual.dtype)
    self.assertEqual(0, len(scores_tensor_info_actual.tensor_shape.dim))

  @test_util.run_deprecated_v1
  def testPredictionSignatureDef(self):
    input1 = constant_op.constant("a", name="input-1")
    input2 = constant_op.constant("b", name="input-2")
    output1 = constant_op.constant("c", name="output-1")
    output2 = constant_op.constant("d", name="output-2")
    signature_def = signature_def_utils_impl.predict_signature_def(
        {"input-1": input1, "input-2": input2},
        {"output-1": output1, "output-2": output2})

    self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
                     signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(2, len(signature_def.inputs))
    input1_tensor_info_actual = signature_def.inputs["input-1"]
    self.assertEqual("input-1:0", input1_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, input1_tensor_info_actual.dtype)
    self.assertEqual(0, len(input1_tensor_info_actual.tensor_shape.dim))
    input2_tensor_info_actual = signature_def.inputs["input-2"]
    self.assertEqual("input-2:0", input2_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, input2_tensor_info_actual.dtype)
    self.assertEqual(0, len(input2_tensor_info_actual.tensor_shape.dim))

    # Check outputs in signature def.
    self.assertEqual(2, len(signature_def.outputs))
    output1_tensor_info_actual = signature_def.outputs["output-1"]
    self.assertEqual("output-1:0", output1_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, output1_tensor_info_actual.dtype)
    self.assertEqual(0, len(output1_tensor_info_actual.tensor_shape.dim))
    output2_tensor_info_actual = signature_def.outputs["output-2"]
    self.assertEqual("output-2:0", output2_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, output2_tensor_info_actual.dtype)
    self.assertEqual(0, len(output2_tensor_info_actual.tensor_shape.dim))

  @test_util.run_deprecated_v1
  def testTrainSignatureDef(self):
    self._testSupervisedSignatureDef(
        signature_def_utils_impl.supervised_train_signature_def,
        signature_constants.SUPERVISED_TRAIN_METHOD_NAME)

  @test_util.run_deprecated_v1
  def testEvalSignatureDef(self):
    self._testSupervisedSignatureDef(
        signature_def_utils_impl.supervised_eval_signature_def,
        signature_constants.SUPERVISED_EVAL_METHOD_NAME)

  def _testSupervisedSignatureDef(self, fn_to_test, method_name):
    """Checks a supervised SignatureDef builder with all arguments supplied.

    Args:
      fn_to_test: builder called as
        `fn_to_test(inputs, loss, predictions, metrics)`.
      method_name: the method name the built SignatureDef is expected to have.
    """
    inputs = {
        "input-1": constant_op.constant("a", name="input-1"),
        "input-2": constant_op.constant("b", name="input-2"),
    }
    loss = {"loss-1": constant_op.constant(0.45, name="loss-1")}
    predictions = {
        "classes": constant_op.constant([100], name="classes"),
    }
    metrics_val = constant_op.constant(100.0, name="metrics_val")
    metrics = {
        "metrics/value": metrics_val,
        "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"),
    }

    signature_def = fn_to_test(inputs, loss, predictions, metrics)

    self.assertEqual(method_name, signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(2, len(signature_def.inputs))
    input1_tensor_info_actual = signature_def.inputs["input-1"]
    self.assertEqual("input-1:0", input1_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, input1_tensor_info_actual.dtype)
    self.assertEqual(0, len(input1_tensor_info_actual.tensor_shape.dim))
    input2_tensor_info_actual = signature_def.inputs["input-2"]
    self.assertEqual("input-2:0", input2_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, input2_tensor_info_actual.dtype)
    self.assertEqual(0, len(input2_tensor_info_actual.tensor_shape.dim))

    # Check outputs in signature def: loss, predictions and both metrics are
    # merged into a single outputs map (1 + 1 + 2 entries).
    self.assertEqual(4, len(signature_def.outputs))
    self.assertEqual("loss-1:0", signature_def.outputs["loss-1"].name)
    self.assertEqual(types_pb2.DT_FLOAT, signature_def.outputs["loss-1"].dtype)

    self.assertEqual("classes:0", signature_def.outputs["classes"].name)
    self.assertEqual(1, len(signature_def.outputs["classes"].tensor_shape.dim))

    self.assertEqual(
        "metrics_val:0", signature_def.outputs["metrics/value"].name)
    self.assertEqual(
        types_pb2.DT_FLOAT, signature_def.outputs["metrics/value"].dtype)

    self.assertEqual(
        "metrics_op:0", signature_def.outputs["metrics/update_op"].name)
    # Previously this re-checked "metrics/value"; verify the update op itself.
    self.assertEqual(
        types_pb2.DT_FLOAT, signature_def.outputs["metrics/update_op"].dtype)

  @test_util.run_deprecated_v1
  def testTrainSignatureDefMissingInputs(self):
    self._testSupervisedSignatureDefMissingInputs(
        signature_def_utils_impl.supervised_train_signature_def,
        signature_constants.SUPERVISED_TRAIN_METHOD_NAME)

  @test_util.run_deprecated_v1
  def testEvalSignatureDefMissingInputs(self):
    self._testSupervisedSignatureDefMissingInputs(
        signature_def_utils_impl.supervised_eval_signature_def,
        signature_constants.SUPERVISED_EVAL_METHOD_NAME)

  def _testSupervisedSignatureDefMissingInputs(self, fn_to_test, method_name):
    """Checks a supervised builder with empty inputs and optional args omitted.

    Args:
      fn_to_test: builder accepting `inputs` positionally and `loss`,
        `predictions`, `metrics` as keyword arguments.
      method_name: the method name the built SignatureDef is expected to have.
    """
    inputs = {
        "input-1": constant_op.constant("a", name="input-1"),
        "input-2": constant_op.constant("b", name="input-2"),
    }
    loss = {"loss-1": constant_op.constant(0.45, name="loss-1")}
    predictions = {
        "classes": constant_op.constant([100], name="classes"),
    }
    metrics_val = constant_op.constant(100, name="metrics_val")
    metrics = {
        "metrics/value": metrics_val,
        "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"),
    }

    # An empty inputs dict must be rejected.
    with self.assertRaises(ValueError):
      fn_to_test({}, loss=loss, predictions=predictions, metrics=metrics)

    signature_def = fn_to_test(inputs, loss=loss)
    self.assertEqual(method_name, signature_def.method_name)
    self.assertEqual(1, len(signature_def.outputs))

    signature_def = fn_to_test(inputs, metrics=metrics, loss=loss)
    self.assertEqual(method_name, signature_def.method_name)
    self.assertEqual(3, len(signature_def.outputs))

  def _assertValidSignature(self, inputs, outputs, method_name):
    """Asserts that the built SignatureDef passes `is_valid_signature`."""
    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, method_name)
    self.assertTrue(
        signature_def_utils_impl.is_valid_signature(signature_def))

  def _assertInvalidSignature(self, inputs, outputs, method_name):
    """Asserts that the built SignatureDef fails `is_valid_signature`."""
    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, method_name)
    self.assertFalse(
        signature_def_utils_impl.is_valid_signature(signature_def))

  def testValidSignaturesAreAccepted(self):
    self._assertValidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertValidSignature(
        {"inputs": _STRING},
        {"classes": _STRING},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertValidSignature(
        {"inputs": _STRING},
        {"scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertValidSignature(
        {"inputs": _STRING},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertValidSignature(
        {"foo": _STRING, "bar": _FLOAT},
        {"baz": _STRING, "qux": _FLOAT},
        signature_constants.PREDICT_METHOD_NAME)

  def testInvalidMethodNameSignatureIsRejected(self):
    # WRONG METHOD
    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _FLOAT},
        "WRONG method name")

  def testInvalidClassificationSignaturesAreRejected(self):
    # CLASSIFY: wrong types
    self._assertInvalidSignature(
        {"inputs": _FLOAT},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _FLOAT, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _STRING},
        signature_constants.CLASSIFY_METHOD_NAME)

    # CLASSIFY: wrong keys
    self._assertInvalidSignature(
        {},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs_WRONG": _STRING},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes_WRONG": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _FLOAT, "extra_WRONG": _STRING},
        signature_constants.CLASSIFY_METHOD_NAME)

  def testInvalidRegressionSignaturesAreRejected(self):
    # REGRESS: wrong types
    self._assertInvalidSignature(
        {"inputs": _FLOAT},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"outputs": _STRING},
        signature_constants.REGRESS_METHOD_NAME)

    # REGRESS: wrong keys
    self._assertInvalidSignature(
        {},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs_WRONG": _STRING},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"outputs_WRONG": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"outputs": _FLOAT, "extra_WRONG": _STRING},
        signature_constants.REGRESS_METHOD_NAME)

  def testInvalidPredictSignaturesAreRejected(self):
    # PREDICT: wrong keys
    self._assertInvalidSignature(
        {},
        {"baz": _STRING, "qux": _FLOAT},
        signature_constants.PREDICT_METHOD_NAME)

    self._assertInvalidSignature(
        {"foo": _STRING, "bar": _FLOAT},
        {},
        signature_constants.PREDICT_METHOD_NAME)

  @test_util.run_v1_only("b/120545219")
  def testOpSignatureDef(self):
    key = "adding_1_and_2_key"
    add_op = math_ops.add(1, 2, name="adding_1_and_2")
    signature_def = signature_def_utils_impl.op_signature_def(add_op, key)
    self.assertIn(key, signature_def.outputs)
    self.assertEqual(add_op.name, signature_def.outputs[key].name)

  @test_util.run_v1_only("b/120545219")
  def testLoadOpFromSignatureDef(self):
    key = "adding_1_and_2_key"
    add_op = math_ops.add(1, 2, name="adding_1_and_2")
    signature_def = signature_def_utils_impl.op_signature_def(add_op, key)

    self.assertEqual(
        add_op,
        signature_def_utils_impl.load_op_from_signature_def(signature_def, key))


if __name__ == "__main__":
  test.main()