1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests of utilities supporting export to SavedModel.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import os 21import tempfile 22import time 23 24from tensorflow.contrib.layers.python.layers import feature_column as fc 25from tensorflow.contrib.learn.python.learn import export_strategy as export_strategy_lib 26from tensorflow.contrib.learn.python.learn.estimators import constants 27from tensorflow.contrib.learn.python.learn.estimators import estimator 28from tensorflow.contrib.learn.python.learn.estimators import model_fn 29from tensorflow.contrib.learn.python.learn.utils import input_fn_utils 30from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils 31from tensorflow.core.framework import tensor_shape_pb2 32from tensorflow.core.framework import types_pb2 33from tensorflow.core.protobuf import meta_graph_pb2 34from tensorflow.python.estimator import estimator as core_estimator 35from tensorflow.python.framework import constant_op 36from tensorflow.python.framework import dtypes 37from tensorflow.python.ops import array_ops 38from tensorflow.python.platform import gfile 39from tensorflow.python.platform import test 40from tensorflow.python.saved_model import signature_constants 41from tensorflow.python.saved_model import signature_def_utils 42from tensorflow.python.util import compat 43 44 45class TestEstimator(estimator.Estimator): 46 47 def __init__(self, *args, **kwargs): 48 super(TestEstimator, self).__init__(*args, **kwargs) 49 self.last_exported_checkpoint = "" 50 self.last_exported_dir = "" 51 52 # @Override 53 def export_savedmodel(self, 54 export_dir, 55 serving_input_fn, 56 default_output_alternative_key=None, 57 assets_extra=None, 58 as_text=False, 59 checkpoint_path=None, 60 strip_default_attrs=False): 61 62 if not os.path.exists(export_dir): 63 os.makedirs(export_dir) 64 65 open(os.path.join(export_dir, "placeholder.txt"), "a").close() 66 67 self.last_exported_checkpoint = checkpoint_path 68 self.last_exported_dir = export_dir 69 70 return export_dir 71 72 73class SavedModelExportUtilsTest(test.TestCase): 74 75 def test_build_standardized_signature_def_regression(self): 76 input_tensors = { 77 "input-1": 78 array_ops.placeholder(dtypes.string, 1, name="input-tensor-1") 79 } 80 output_tensors = { 81 "output-1": 82 array_ops.placeholder(dtypes.float32, 1, name="output-tensor-1") 83 } 84 problem_type = constants.ProblemType.LINEAR_REGRESSION 85 actual_signature_def = ( 86 saved_model_export_utils.build_standardized_signature_def( 87 input_tensors, output_tensors, problem_type)) 88 expected_signature_def = meta_graph_pb2.SignatureDef() 89 shape = tensor_shape_pb2.TensorShapeProto( 90 dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) 91 dtype_float = types_pb2.DataType.Value("DT_FLOAT") 92 dtype_string = types_pb2.DataType.Value("DT_STRING") 93 expected_signature_def.inputs[signature_constants.REGRESS_INPUTS].CopyFrom( 94 meta_graph_pb2.TensorInfo( 95 name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) 96 expected_signature_def.outputs[ 97 signature_constants.REGRESS_OUTPUTS].CopyFrom( 98 meta_graph_pb2.TensorInfo( 99 name="output-tensor-1:0", dtype=dtype_float, 100 tensor_shape=shape)) 101 102 expected_signature_def.method_name = signature_constants.REGRESS_METHOD_NAME 103 self.assertEqual(actual_signature_def, expected_signature_def) 104 105 def test_build_standardized_signature_def_classification(self): 106 """Tests classification with one output tensor.""" 107 input_tensors = { 108 "input-1": 109 array_ops.placeholder(dtypes.string, 1, name="input-tensor-1") 110 } 111 output_tensors = { 112 "output-1": 113 array_ops.placeholder(dtypes.string, 1, name="output-tensor-1") 114 } 115 problem_type = constants.ProblemType.CLASSIFICATION 116 actual_signature_def = ( 117 saved_model_export_utils.build_standardized_signature_def( 118 input_tensors, output_tensors, problem_type)) 119 expected_signature_def = meta_graph_pb2.SignatureDef() 120 shape = tensor_shape_pb2.TensorShapeProto( 121 dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) 122 dtype_string = types_pb2.DataType.Value("DT_STRING") 123 expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom( 124 meta_graph_pb2.TensorInfo( 125 name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) 126 expected_signature_def.outputs[ 127 signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom( 128 meta_graph_pb2.TensorInfo( 129 name="output-tensor-1:0", 130 dtype=dtype_string, 131 tensor_shape=shape)) 132 133 expected_signature_def.method_name = ( 134 signature_constants.CLASSIFY_METHOD_NAME) 135 self.assertEqual(actual_signature_def, expected_signature_def) 136 137 def test_build_standardized_signature_def_classification2(self): 138 """Tests multiple output tensors that include classes and probabilities.""" 139 input_tensors = { 140 "input-1": 141 array_ops.placeholder(dtypes.string, 1, name="input-tensor-1") 142 } 143 output_tensors = { 144 "classes": 145 array_ops.placeholder( 146 dtypes.string, 1, name="output-tensor-classes"), 147 # Will be used for CLASSIFY_OUTPUT_SCORES. 148 "probabilities": 149 array_ops.placeholder( 150 dtypes.float32, 1, name="output-tensor-proba"), 151 "logits": 152 array_ops.placeholder( 153 dtypes.float32, 1, name="output-tensor-logits-unused"), 154 } 155 problem_type = constants.ProblemType.CLASSIFICATION 156 actual_signature_def = ( 157 saved_model_export_utils.build_standardized_signature_def( 158 input_tensors, output_tensors, problem_type)) 159 expected_signature_def = meta_graph_pb2.SignatureDef() 160 shape = tensor_shape_pb2.TensorShapeProto( 161 dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) 162 dtype_float = types_pb2.DataType.Value("DT_FLOAT") 163 dtype_string = types_pb2.DataType.Value("DT_STRING") 164 expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom( 165 meta_graph_pb2.TensorInfo( 166 name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) 167 expected_signature_def.outputs[ 168 signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom( 169 meta_graph_pb2.TensorInfo( 170 name="output-tensor-classes:0", 171 dtype=dtype_string, 172 tensor_shape=shape)) 173 expected_signature_def.outputs[ 174 signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom( 175 meta_graph_pb2.TensorInfo( 176 name="output-tensor-proba:0", 177 dtype=dtype_float, 178 tensor_shape=shape)) 179 180 expected_signature_def.method_name = ( 181 signature_constants.CLASSIFY_METHOD_NAME) 182 self.assertEqual(actual_signature_def, expected_signature_def) 183 184 def test_build_standardized_signature_def_classification3(self): 185 """Tests multiple output tensors that include classes and scores.""" 186 input_tensors = { 187 "input-1": 188 array_ops.placeholder(dtypes.string, 1, name="input-tensor-1") 189 } 190 output_tensors = { 191 "classes": 192 array_ops.placeholder( 193 dtypes.string, 1, name="output-tensor-classes"), 194 "scores": 195 array_ops.placeholder( 196 dtypes.float32, 1, name="output-tensor-scores"), 197 "logits": 198 array_ops.placeholder( 199 dtypes.float32, 1, name="output-tensor-logits-unused"), 200 } 201 problem_type = constants.ProblemType.CLASSIFICATION 202 actual_signature_def = ( 203 saved_model_export_utils.build_standardized_signature_def( 204 input_tensors, output_tensors, problem_type)) 205 expected_signature_def = meta_graph_pb2.SignatureDef() 206 shape = tensor_shape_pb2.TensorShapeProto( 207 dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) 208 dtype_float = types_pb2.DataType.Value("DT_FLOAT") 209 dtype_string = types_pb2.DataType.Value("DT_STRING") 210 expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom( 211 meta_graph_pb2.TensorInfo( 212 name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) 213 expected_signature_def.outputs[ 214 signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom( 215 meta_graph_pb2.TensorInfo( 216 name="output-tensor-classes:0", 217 dtype=dtype_string, 218 tensor_shape=shape)) 219 expected_signature_def.outputs[ 220 signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom( 221 meta_graph_pb2.TensorInfo( 222 name="output-tensor-scores:0", 223 dtype=dtype_float, 224 tensor_shape=shape)) 225 226 expected_signature_def.method_name = ( 227 signature_constants.CLASSIFY_METHOD_NAME) 228 self.assertEqual(actual_signature_def, expected_signature_def) 229 230 def test_build_standardized_signature_def_classification4(self): 231 """Tests classification without classes tensor.""" 232 input_tensors = { 233 "input-1": 234 array_ops.placeholder(dtypes.string, 1, name="input-tensor-1") 235 } 236 output_tensors = { 237 "probabilities": 238 array_ops.placeholder( 239 dtypes.float32, 1, name="output-tensor-proba"), 240 "logits": 241 array_ops.placeholder( 242 dtypes.float32, 1, name="output-tensor-logits-unused"), 243 } 244 problem_type = constants.ProblemType.CLASSIFICATION 245 actual_signature_def = ( 246 saved_model_export_utils.build_standardized_signature_def( 247 input_tensors, output_tensors, problem_type)) 248 expected_signature_def = meta_graph_pb2.SignatureDef() 249 shape = tensor_shape_pb2.TensorShapeProto( 250 dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) 251 dtype_float = types_pb2.DataType.Value("DT_FLOAT") 252 dtype_string = types_pb2.DataType.Value("DT_STRING") 253 expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom( 254 meta_graph_pb2.TensorInfo( 255 name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) 256 expected_signature_def.outputs[ 257 signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom( 258 meta_graph_pb2.TensorInfo( 259 name="output-tensor-proba:0", 260 dtype=dtype_float, 261 tensor_shape=shape)) 262 263 expected_signature_def.method_name = ( 264 signature_constants.CLASSIFY_METHOD_NAME) 265 self.assertEqual(actual_signature_def, expected_signature_def) 266 267 def test_build_standardized_signature_def_classification5(self): 268 """Tests multiple output tensors that include integer classes and scores. 269 270 Integer classes are dropped out, because Servo classification can only serve 271 string classes. So, only scores are present in the signature. 272 """ 273 input_tensors = { 274 "input-1": 275 array_ops.placeholder(dtypes.string, 1, name="input-tensor-1") 276 } 277 output_tensors = { 278 "classes": 279 array_ops.placeholder( 280 dtypes.int64, 1, name="output-tensor-classes"), 281 "scores": 282 array_ops.placeholder( 283 dtypes.float32, 1, name="output-tensor-scores"), 284 "logits": 285 array_ops.placeholder( 286 dtypes.float32, 1, name="output-tensor-logits-unused"), 287 } 288 problem_type = constants.ProblemType.CLASSIFICATION 289 actual_signature_def = ( 290 saved_model_export_utils.build_standardized_signature_def( 291 input_tensors, output_tensors, problem_type)) 292 expected_signature_def = meta_graph_pb2.SignatureDef() 293 shape = tensor_shape_pb2.TensorShapeProto( 294 dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) 295 dtype_float = types_pb2.DataType.Value("DT_FLOAT") 296 dtype_string = types_pb2.DataType.Value("DT_STRING") 297 expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom( 298 meta_graph_pb2.TensorInfo( 299 name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) 300 expected_signature_def.outputs[ 301 signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom( 302 meta_graph_pb2.TensorInfo( 303 name="output-tensor-scores:0", 304 dtype=dtype_float, 305 tensor_shape=shape)) 306 307 expected_signature_def.method_name = ( 308 signature_constants.CLASSIFY_METHOD_NAME) 309 self.assertEqual(actual_signature_def, expected_signature_def) 310 311 def test_build_standardized_signature_def_classification6(self): 312 """Tests multiple output tensors that with integer classes and no scores. 313 314 Servo classification cannot serve integer classes, but no scores are 315 available. So, we fall back to predict signature. 316 """ 317 input_tensors = { 318 "input-1": 319 array_ops.placeholder(dtypes.string, 1, name="input-tensor-1") 320 } 321 output_tensors = { 322 "classes": 323 array_ops.placeholder( 324 dtypes.int64, 1, name="output-tensor-classes"), 325 "logits": 326 array_ops.placeholder( 327 dtypes.float32, 1, name="output-tensor-logits"), 328 } 329 problem_type = constants.ProblemType.CLASSIFICATION 330 actual_signature_def = ( 331 saved_model_export_utils.build_standardized_signature_def( 332 input_tensors, output_tensors, problem_type)) 333 expected_signature_def = meta_graph_pb2.SignatureDef() 334 shape = tensor_shape_pb2.TensorShapeProto( 335 dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) 336 dtype_int64 = types_pb2.DataType.Value("DT_INT64") 337 dtype_float = types_pb2.DataType.Value("DT_FLOAT") 338 dtype_string = types_pb2.DataType.Value("DT_STRING") 339 expected_signature_def.inputs["input-1"].CopyFrom( 340 meta_graph_pb2.TensorInfo( 341 name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) 342 expected_signature_def.outputs["classes"].CopyFrom( 343 meta_graph_pb2.TensorInfo( 344 name="output-tensor-classes:0", 345 dtype=dtype_int64, 346 tensor_shape=shape)) 347 expected_signature_def.outputs["logits"].CopyFrom( 348 meta_graph_pb2.TensorInfo( 349 name="output-tensor-logits:0", 350 dtype=dtype_float, 351 tensor_shape=shape)) 352 353 expected_signature_def.method_name = ( 354 signature_constants.PREDICT_METHOD_NAME) 355 self.assertEqual(actual_signature_def, expected_signature_def) 356 357 def test_get_input_alternatives(self): 358 input_ops = input_fn_utils.InputFnOps("bogus features dict", None, 359 "bogus default input dict") 360 361 input_alternatives, _ = saved_model_export_utils.get_input_alternatives( 362 input_ops) 363 self.assertEqual(input_alternatives[ 364 saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY], 365 "bogus default input dict") 366 # self.assertEqual(input_alternatives[ 367 # saved_model_export_utils.FEATURES_INPUT_ALTERNATIVE_KEY], 368 # "bogus features dict") 369 370 def test_get_output_alternatives_explicit_default(self): 371 provided_output_alternatives = { 372 "head-1": (constants.ProblemType.LINEAR_REGRESSION, 373 "bogus output dict"), 374 "head-2": (constants.ProblemType.CLASSIFICATION, "bogus output dict 2"), 375 "head-3": (constants.ProblemType.UNSPECIFIED, "bogus output dict 3"), 376 } 377 model_fn_ops = model_fn.ModelFnOps( 378 model_fn.ModeKeys.INFER, 379 predictions={"some_output": "bogus_tensor"}, 380 output_alternatives=provided_output_alternatives) 381 382 output_alternatives, _ = saved_model_export_utils.get_output_alternatives( 383 model_fn_ops, "head-1") 384 385 self.assertEqual(provided_output_alternatives, output_alternatives) 386 387 def test_get_output_alternatives_wrong_default(self): 388 provided_output_alternatives = { 389 "head-1": (constants.ProblemType.LINEAR_REGRESSION, 390 "bogus output dict"), 391 "head-2": (constants.ProblemType.CLASSIFICATION, "bogus output dict 2"), 392 "head-3": (constants.ProblemType.UNSPECIFIED, "bogus output dict 3"), 393 } 394 model_fn_ops = model_fn.ModelFnOps( 395 model_fn.ModeKeys.INFER, 396 predictions={"some_output": "bogus_tensor"}, 397 output_alternatives=provided_output_alternatives) 398 399 with self.assertRaises(ValueError) as e: 400 saved_model_export_utils.get_output_alternatives(model_fn_ops, "WRONG") 401 402 self.assertEqual("Requested default_output_alternative: WRONG, but " 403 "available output_alternatives are: ['head-1', 'head-2', " 404 "'head-3']", str(e.exception)) 405 406 def test_get_output_alternatives_single_no_default(self): 407 prediction_tensor = constant_op.constant(["bogus"]) 408 provided_output_alternatives = { 409 "head-1": (constants.ProblemType.LINEAR_REGRESSION, { 410 "output": prediction_tensor 411 }), 412 } 413 model_fn_ops = model_fn.ModelFnOps( 414 model_fn.ModeKeys.INFER, 415 predictions=prediction_tensor, 416 output_alternatives=provided_output_alternatives) 417 418 output_alternatives, _ = saved_model_export_utils.get_output_alternatives( 419 model_fn_ops) 420 421 self.assertEqual({ 422 "head-1": (constants.ProblemType.LINEAR_REGRESSION, { 423 "output": prediction_tensor 424 }) 425 }, output_alternatives) 426 427 def test_get_output_alternatives_multi_no_default(self): 428 provided_output_alternatives = { 429 "head-1": (constants.ProblemType.LINEAR_REGRESSION, 430 "bogus output dict"), 431 "head-2": (constants.ProblemType.CLASSIFICATION, "bogus output dict 2"), 432 "head-3": (constants.ProblemType.UNSPECIFIED, "bogus output dict 3"), 433 } 434 model_fn_ops = model_fn.ModelFnOps( 435 model_fn.ModeKeys.INFER, 436 predictions={"some_output": "bogus_tensor"}, 437 output_alternatives=provided_output_alternatives) 438 439 with self.assertRaises(ValueError) as e: 440 saved_model_export_utils.get_output_alternatives(model_fn_ops) 441 442 self.assertEqual("Please specify a default_output_alternative. Available " 443 "output_alternatives are: ['head-1', 'head-2', 'head-3']", 444 str(e.exception)) 445 446 def test_get_output_alternatives_none_provided(self): 447 prediction_tensor = constant_op.constant(["bogus"]) 448 model_fn_ops = model_fn.ModelFnOps( 449 model_fn.ModeKeys.INFER, 450 predictions={"some_output": prediction_tensor}, 451 output_alternatives=None) 452 453 output_alternatives, _ = saved_model_export_utils.get_output_alternatives( 454 model_fn_ops) 455 456 self.assertEqual({ 457 "default_output_alternative": (constants.ProblemType.UNSPECIFIED, { 458 "some_output": prediction_tensor 459 }) 460 }, output_alternatives) 461 462 def test_get_output_alternatives_empty_provided_with_default(self): 463 prediction_tensor = constant_op.constant(["bogus"]) 464 model_fn_ops = model_fn.ModelFnOps( 465 model_fn.ModeKeys.INFER, 466 predictions={"some_output": prediction_tensor}, 467 output_alternatives={}) 468 469 with self.assertRaises(ValueError) as e: 470 saved_model_export_utils.get_output_alternatives(model_fn_ops, "WRONG") 471 472 self.assertEqual("Requested default_output_alternative: WRONG, but " 473 "available output_alternatives are: []", str(e.exception)) 474 475 def test_get_output_alternatives_empty_provided_no_default(self): 476 prediction_tensor = constant_op.constant(["bogus"]) 477 model_fn_ops = model_fn.ModelFnOps( 478 model_fn.ModeKeys.INFER, 479 predictions={"some_output": prediction_tensor}, 480 output_alternatives={}) 481 482 output_alternatives, _ = saved_model_export_utils.get_output_alternatives( 483 model_fn_ops) 484 485 self.assertEqual({ 486 "default_output_alternative": (constants.ProblemType.UNSPECIFIED, { 487 "some_output": prediction_tensor 488 }) 489 }, output_alternatives) 490 491 def test_get_output_alternatives_implicit_single(self): 492 prediction_tensor = constant_op.constant(["bogus"]) 493 model_fn_ops = model_fn.ModelFnOps( 494 model_fn.ModeKeys.INFER, 495 predictions=prediction_tensor, 496 output_alternatives=None) 497 498 output_alternatives, _ = saved_model_export_utils.get_output_alternatives( 499 model_fn_ops) 500 self.assertEqual({ 501 "default_output_alternative": (constants.ProblemType.UNSPECIFIED, { 502 "output": prediction_tensor 503 }) 504 }, output_alternatives) 505 506 def test_build_all_signature_defs(self): 507 input_features = constant_op.constant(["10"]) 508 input_example = constant_op.constant(["input string"]) 509 input_ops = input_fn_utils.InputFnOps({ 510 "features": input_features 511 }, None, { 512 "default input": input_example 513 }) 514 input_alternatives, _ = ( 515 saved_model_export_utils.get_input_alternatives(input_ops)) 516 output_1 = constant_op.constant([1.0]) 517 output_2 = constant_op.constant(["2"]) 518 output_3 = constant_op.constant(["3"]) 519 provided_output_alternatives = { 520 "head-1": (constants.ProblemType.LINEAR_REGRESSION, { 521 "some_output_1": output_1 522 }), 523 "head-2": (constants.ProblemType.CLASSIFICATION, { 524 "some_output_2": output_2 525 }), 526 "head-3": (constants.ProblemType.UNSPECIFIED, { 527 "some_output_3": output_3 528 }), 529 } 530 model_fn_ops = model_fn.ModelFnOps( 531 model_fn.ModeKeys.INFER, 532 predictions={"some_output": constant_op.constant(["4"])}, 533 output_alternatives=provided_output_alternatives) 534 output_alternatives, _ = ( 535 saved_model_export_utils.get_output_alternatives( 536 model_fn_ops, "head-1")) 537 538 signature_defs = saved_model_export_utils.build_all_signature_defs( 539 input_alternatives, output_alternatives, "head-1") 540 541 expected_signature_defs = { 542 "serving_default": 543 signature_def_utils.regression_signature_def( 544 input_example, output_1), 545 "default_input_alternative:head-1": 546 signature_def_utils.regression_signature_def( 547 input_example, output_1), 548 "default_input_alternative:head-2": 549 signature_def_utils.classification_signature_def( 550 input_example, output_2, None), 551 "default_input_alternative:head-3": 552 signature_def_utils.predict_signature_def({ 553 "default input": input_example 554 }, { 555 "some_output_3": output_3 556 }), 557 # "features_input_alternative:head-1": 558 # signature_def_utils.regression_signature_def(input_features, 559 # output_1), 560 # "features_input_alternative:head-2": 561 # signature_def_utils.classification_signature_def(input_features, 562 # output_2, None), 563 # "features_input_alternative:head-3": 564 # signature_def_utils.predict_signature_def({ 565 # "input": input_features 566 # }, {"output": output_3}), 567 } 568 569 self.assertDictEqual(expected_signature_defs, signature_defs) 570 571 def test_build_all_signature_defs_legacy_input_fn_not_supported(self): 572 """Tests that legacy input_fn returning (features, labels) raises error. 573 574 serving_input_fn must return InputFnOps including a default input 575 alternative. 576 """ 577 input_features = constant_op.constant(["10"]) 578 input_ops = ({"features": input_features}, None) 579 input_alternatives, _ = ( 580 saved_model_export_utils.get_input_alternatives(input_ops)) 581 output_1 = constant_op.constant(["1"]) 582 output_2 = constant_op.constant(["2"]) 583 output_3 = constant_op.constant(["3"]) 584 provided_output_alternatives = { 585 "head-1": (constants.ProblemType.LINEAR_REGRESSION, { 586 "some_output_1": output_1 587 }), 588 "head-2": (constants.ProblemType.CLASSIFICATION, { 589 "some_output_2": output_2 590 }), 591 "head-3": (constants.ProblemType.UNSPECIFIED, { 592 "some_output_3": output_3 593 }), 594 } 595 model_fn_ops = model_fn.ModelFnOps( 596 model_fn.ModeKeys.INFER, 597 predictions={"some_output": constant_op.constant(["4"])}, 598 output_alternatives=provided_output_alternatives) 599 output_alternatives, _ = ( 600 saved_model_export_utils.get_output_alternatives( 601 model_fn_ops, "head-1")) 602 603 with self.assertRaisesRegexp( 604 ValueError, "A default input_alternative must be provided"): 605 saved_model_export_utils.build_all_signature_defs( 606 input_alternatives, output_alternatives, "head-1") 607 608 def test_get_timestamped_export_dir(self): 609 export_dir_base = tempfile.mkdtemp() + "export/" 610 export_dir_1 = saved_model_export_utils.get_timestamped_export_dir( 611 export_dir_base) 612 time.sleep(2) 613 export_dir_2 = saved_model_export_utils.get_timestamped_export_dir( 614 export_dir_base) 615 time.sleep(2) 616 export_dir_3 = saved_model_export_utils.get_timestamped_export_dir( 617 export_dir_base) 618 619 # Export directories should be named using a timestamp that is seconds 620 # since epoch. Such a timestamp is 10 digits long. 621 time_1 = os.path.basename(export_dir_1) 622 self.assertEqual(10, len(time_1)) 623 time_2 = os.path.basename(export_dir_2) 624 self.assertEqual(10, len(time_2)) 625 time_3 = os.path.basename(export_dir_3) 626 self.assertEqual(10, len(time_3)) 627 628 self.assertTrue(int(time_1) < int(time_2)) 629 self.assertTrue(int(time_2) < int(time_3)) 630 631 def test_garbage_collect_exports(self): 632 export_dir_base = tempfile.mkdtemp() + "export/" 633 gfile.MkDir(export_dir_base) 634 export_dir_1 = _create_test_export_dir(export_dir_base) 635 export_dir_2 = _create_test_export_dir(export_dir_base) 636 export_dir_3 = _create_test_export_dir(export_dir_base) 637 export_dir_4 = _create_test_export_dir(export_dir_base) 638 639 self.assertTrue(gfile.Exists(export_dir_1)) 640 self.assertTrue(gfile.Exists(export_dir_2)) 641 self.assertTrue(gfile.Exists(export_dir_3)) 642 self.assertTrue(gfile.Exists(export_dir_4)) 643 644 # Garbage collect all but the most recent 2 exports, 645 # where recency is determined based on the timestamp directory names. 646 saved_model_export_utils.garbage_collect_exports(export_dir_base, 2) 647 648 self.assertFalse(gfile.Exists(export_dir_1)) 649 self.assertFalse(gfile.Exists(export_dir_2)) 650 self.assertTrue(gfile.Exists(export_dir_3)) 651 self.assertTrue(gfile.Exists(export_dir_4)) 652 653 def test_get_most_recent_export(self): 654 export_dir_base = tempfile.mkdtemp() + "export/" 655 gfile.MkDir(export_dir_base) 656 _create_test_export_dir(export_dir_base) 657 _create_test_export_dir(export_dir_base) 658 _create_test_export_dir(export_dir_base) 659 export_dir_4 = _create_test_export_dir(export_dir_base) 660 661 (most_recent_export_dir, most_recent_export_version) = ( 662 saved_model_export_utils.get_most_recent_export(export_dir_base)) 663 664 self.assertEqual( 665 compat.as_bytes(export_dir_4), compat.as_bytes(most_recent_export_dir)) 666 self.assertEqual( 667 compat.as_bytes(export_dir_4), 668 os.path.join( 669 compat.as_bytes(export_dir_base), 670 compat.as_bytes(str(most_recent_export_version)))) 671 672 def test_make_export_strategy(self): 673 """Only tests that an ExportStrategy instance is created.""" 674 675 def _serving_input_fn(): 676 return array_ops.constant([1]), None 677 678 export_strategy = saved_model_export_utils.make_export_strategy( 679 serving_input_fn=_serving_input_fn, 680 default_output_alternative_key="default", 681 assets_extra={"from/path": "to/path"}, 682 as_text=False, 683 exports_to_keep=5) 684 self.assertTrue( 685 isinstance(export_strategy, export_strategy_lib.ExportStrategy)) 686 687 def test_make_parsing_export_strategy(self): 688 """Only tests that an ExportStrategy instance is created.""" 689 sparse_col = fc.sparse_column_with_hash_bucket( 690 "sparse_column", hash_bucket_size=100) 691 embedding_col = fc.embedding_column( 692 fc.sparse_column_with_hash_bucket( 693 "sparse_column_for_embedding", hash_bucket_size=10), 694 dimension=4) 695 real_valued_col1 = fc.real_valued_column("real_valued_column1") 696 bucketized_col1 = fc.bucketized_column( 697 fc.real_valued_column("real_valued_column_for_bucketization1"), [0, 4]) 698 feature_columns = [ 699 sparse_col, embedding_col, real_valued_col1, bucketized_col1 700 ] 701 702 export_strategy = saved_model_export_utils.make_parsing_export_strategy( 703 feature_columns=feature_columns) 704 self.assertTrue( 705 isinstance(export_strategy, export_strategy_lib.ExportStrategy)) 706 707 def test_make_best_model_export_strategy(self): 708 export_dir_base = tempfile.mkdtemp() + "export/" 709 gfile.MkDir(export_dir_base) 710 711 test_estimator = TestEstimator() 712 export_strategy = saved_model_export_utils.make_best_model_export_strategy( 713 serving_input_fn=None, exports_to_keep=3, compare_fn=None) 714 715 self.assertNotEqual("", 716 export_strategy.export(test_estimator, export_dir_base, 717 "fake_ckpt_0", { 718 "loss": 100 719 })) 720 self.assertNotEqual("", test_estimator.last_exported_dir) 721 self.assertNotEqual("", test_estimator.last_exported_checkpoint) 722 723 self.assertEqual("", 724 export_strategy.export(test_estimator, export_dir_base, 725 "fake_ckpt_1", { 726 "loss": 101 727 })) 728 self.assertEqual(test_estimator.last_exported_dir, 729 os.path.join(export_dir_base, "fake_ckpt_0")) 730 731 self.assertNotEqual("", 732 export_strategy.export(test_estimator, export_dir_base, 733 "fake_ckpt_2", { 734 "loss": 10 735 })) 736 self.assertEqual(test_estimator.last_exported_dir, 737 os.path.join(export_dir_base, "fake_ckpt_2")) 738 739 self.assertEqual("", 740 export_strategy.export(test_estimator, export_dir_base, 741 "fake_ckpt_3", { 742 "loss": 20 743 })) 744 self.assertEqual(test_estimator.last_exported_dir, 745 os.path.join(export_dir_base, "fake_ckpt_2")) 746 747 def test_make_best_model_export_strategy_with_preemption(self): 748 model_dir = self.get_temp_dir() 749 eval_dir_base = os.path.join(model_dir, "eval_continuous") 750 core_estimator._write_dict_to_summary(eval_dir_base, {"loss": 50}, 1) 751 core_estimator._write_dict_to_summary(eval_dir_base, {"loss": 60}, 2) 752 753 test_estimator = TestEstimator() 754 export_strategy = saved_model_export_utils.make_best_model_export_strategy( 755 serving_input_fn=None, 756 exports_to_keep=3, 757 model_dir=model_dir, 758 event_file_pattern="eval_continuous/*.tfevents.*", 759 compare_fn=None) 760 761 export_dir_base = os.path.join(self.get_temp_dir(), "export") 762 self.assertEqual("", 763 export_strategy.export(test_estimator, export_dir_base, 764 "fake_ckpt_0", { 765 "loss": 100 766 })) 767 self.assertEqual("", test_estimator.last_exported_dir) 768 self.assertEqual("", test_estimator.last_exported_checkpoint) 769 770 self.assertNotEqual("", 771 export_strategy.export(test_estimator, export_dir_base, 772 "fake_ckpt_2", { 773 "loss": 10 774 })) 775 self.assertEqual(test_estimator.last_exported_dir, 776 os.path.join(export_dir_base, "fake_ckpt_2")) 777 778 self.assertEqual("", 779 export_strategy.export(test_estimator, export_dir_base, 780 "fake_ckpt_3", { 781 "loss": 20 782 })) 783 self.assertEqual(test_estimator.last_exported_dir, 784 os.path.join(export_dir_base, "fake_ckpt_2")) 785 786 def test_make_best_model_export_strategy_exceptions(self): 787 export_dir_base = tempfile.mkdtemp() + "export/" 788 789 test_estimator = TestEstimator() 790 export_strategy = saved_model_export_utils.make_best_model_export_strategy( 791 serving_input_fn=None, exports_to_keep=3, compare_fn=None) 792 793 with self.assertRaises(ValueError): 794 export_strategy.export(test_estimator, export_dir_base, "", {"loss": 200}) 795 796 with self.assertRaises(ValueError): 797 export_strategy.export(test_estimator, export_dir_base, "fake_ckpt_1", 798 None) 799 800 def test_extend_export_strategy(self): 801 802 def _base_export_fn(unused_estimator, 803 export_dir_base, 804 unused_checkpoint_path=None): 805 base_path = os.path.join(export_dir_base, "e1") 806 gfile.MkDir(base_path) 807 return base_path 808 809 def _post_export_fn(orig_path, new_path): 810 assert orig_path.endswith("/e1") 811 post_export_path = os.path.join(new_path, "rewrite") 812 gfile.MkDir(post_export_path) 813 return post_export_path 814 815 base_export_strategy = export_strategy_lib.ExportStrategy( 816 "Servo", _base_export_fn) 817 818 final_export_strategy = saved_model_export_utils.extend_export_strategy( 819 base_export_strategy, _post_export_fn, "Servo2") 820 self.assertEqual(final_export_strategy.name, "Servo2") 821 822 test_estimator = TestEstimator() 823 tmpdir = tempfile.mkdtemp() 824 export_model_dir = os.path.join(tmpdir, "model") 825 checkpoint_path = os.path.join(tmpdir, "checkpoint") 826 final_path = final_export_strategy.export(test_estimator, export_model_dir, 827 checkpoint_path) 828 self.assertEqual(os.path.join(export_model_dir, "rewrite"), final_path) 829 830 def test_extend_export_strategy_same_name(self): 831 832 def _base_export_fn(unused_estimator, 833 export_dir_base, 834 unused_checkpoint_path=None): 835 base_path = os.path.join(export_dir_base, "e1") 836 gfile.MkDir(base_path) 837 return base_path 838 839 def _post_export_fn(orig_path, new_path): 840 assert orig_path.endswith("/e1") 841 post_export_path = os.path.join(new_path, "rewrite") 842 gfile.MkDir(post_export_path) 843 return post_export_path 844 845 base_export_strategy = export_strategy_lib.ExportStrategy( 846 "Servo", _base_export_fn) 847 848 final_export_strategy = saved_model_export_utils.extend_export_strategy( 849 base_export_strategy, _post_export_fn) 850 self.assertEqual(final_export_strategy.name, "Servo") 851 852 test_estimator = TestEstimator() 853 tmpdir = tempfile.mkdtemp() 854 export_model_dir = os.path.join(tmpdir, "model") 855 checkpoint_path = os.path.join(tmpdir, "checkpoint") 856 final_path = final_export_strategy.export(test_estimator, export_model_dir, 857 checkpoint_path) 858 self.assertEqual(os.path.join(export_model_dir, "rewrite"), final_path) 859 860 def test_extend_export_strategy_raises_error(self): 861 862 def _base_export_fn(unused_estimator, 863 export_dir_base, 864 unused_checkpoint_path=None): 865 base_path = os.path.join(export_dir_base, "e1") 866 gfile.MkDir(base_path) 867 return base_path 868 869 def _post_export_fn(unused_orig_path, unused_new_path): 870 return tempfile.mkdtemp() 871 872 base_export_strategy = export_strategy_lib.ExportStrategy( 873 "Servo", _base_export_fn) 874 875 final_export_strategy = saved_model_export_utils.extend_export_strategy( 876 base_export_strategy, _post_export_fn) 877 878 test_estimator = TestEstimator() 879 tmpdir = tempfile.mkdtemp() 880 with self.assertRaises(ValueError) as ve: 881 final_export_strategy.export(test_estimator, tmpdir, 882 os.path.join(tmpdir, "checkpoint")) 883 884 self.assertTrue( 885 "post_export_fn must return a sub-directory" in str(ve.exception)) 886 887 888def _create_test_export_dir(export_dir_base): 889 export_dir = saved_model_export_utils.get_timestamped_export_dir( 890 export_dir_base) 891 gfile.MkDir(export_dir) 892 time.sleep(2) 893 return export_dir 894 895 896if __name__ == "__main__": 897 test.main() 898