1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/contrib/session_bundle/bundle_shim.h"
17
18 #include "google/protobuf/any.pb.h"
19 #include "tensorflow/cc/saved_model/signature_constants.h"
20 #include "tensorflow/cc/saved_model/tag_constants.h"
21 #include "tensorflow/contrib/session_bundle/test_util.h"
22 #include "tensorflow/core/example/example.pb.h"
23 #include "tensorflow/core/example/feature.pb.h"
24 #include "tensorflow/core/framework/tensor_testutil.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/lib/io/path.h"
27 #include "tensorflow/core/protobuf/meta_graph.pb.h"
28
29 namespace tensorflow {
30 namespace serving {
31 namespace internal {
32 namespace {
33
// Location of the half_plus_two SessionBundle export, relative to the
// directory that test_util::TestSrcDirPath() resolves against.
constexpr char kSessionBundlePath[] =
    "session_bundle/testdata/"
    "half_plus_two/00000123";
// Location of the half_plus_two SavedModel export, relative to
// testing::TensorFlowSrcRoot().
constexpr char kSavedModelBundlePath[] =
    "cc/saved_model/testdata/"
    "half_plus_two/00000123";
38
MakeSerializedExample(float x)39 string MakeSerializedExample(float x) {
40 tensorflow::Example example;
41 auto* feature_map = example.mutable_features()->mutable_feature();
42 (*feature_map)["x"].mutable_float_list()->add_value(x);
43 return example.SerializeAsString();
44 }
45
ValidateHalfPlusTwo(const SavedModelBundle & saved_model_bundle,const string & input_tensor_name,const string & output_tensor_name)46 void ValidateHalfPlusTwo(const SavedModelBundle& saved_model_bundle,
47 const string& input_tensor_name,
48 const string& output_tensor_name) {
49 // Validate the half plus two behavior.
50 std::vector<string> serialized_examples;
51 for (float x : {0, 1, 2, 3}) {
52 serialized_examples.push_back(MakeSerializedExample(x));
53 }
54 Tensor input = test::AsTensor<string>(serialized_examples, TensorShape({4}));
55
56 std::vector<Tensor> outputs;
57 TF_ASSERT_OK(saved_model_bundle.session->Run(
58 {{input_tensor_name, input}}, {output_tensor_name}, {}, &outputs));
59 ASSERT_EQ(outputs.size(), 1);
60 test::ExpectTensorEqual<float>(
61 outputs[0], test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
62 }
63
LoadAndValidateSavedModelBundle(const string & export_dir,const std::unordered_set<string> & tags,const string & signature_def_key,bool expect_session_bundle)64 void LoadAndValidateSavedModelBundle(const string& export_dir,
65 const std::unordered_set<string>& tags,
66 const string& signature_def_key,
67 bool expect_session_bundle) {
68 SessionOptions session_options;
69 RunOptions run_options;
70 SavedModelBundle saved_model_bundle;
71 bool is_session_bundle = false;
72 TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(
73 session_options, run_options, export_dir, tags, &saved_model_bundle,
74 &is_session_bundle));
75 EXPECT_EQ(expect_session_bundle, is_session_bundle);
76 const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
77 const auto& signature_def_map = meta_graph_def.signature_def();
78
79 const auto& regression_entry = signature_def_map.find(signature_def_key);
80 ASSERT_FALSE(regression_entry == signature_def_map.end());
81 SignatureDef regression_signature_def = regression_entry->second;
82
83 EXPECT_EQ(1, regression_signature_def.inputs_size());
84 ASSERT_FALSE(regression_signature_def.inputs().find(kRegressInputs) ==
85 regression_signature_def.inputs().end());
86 TensorInfo input_tensor_info =
87 regression_signature_def.inputs().find(kRegressInputs)->second;
88 EXPECT_EQ(1, regression_signature_def.outputs_size());
89 // Ensure the TensorInfo has dtype populated.
90 EXPECT_EQ(DT_STRING, input_tensor_info.dtype());
91
92 ASSERT_FALSE(regression_signature_def.outputs().find(kRegressOutputs) ==
93 regression_signature_def.outputs().end());
94 TensorInfo output_tensor_info =
95 regression_signature_def.outputs().find(kRegressOutputs)->second;
96 // Ensure the TensorInfo has dtype populated.
97 EXPECT_EQ(DT_FLOAT, output_tensor_info.dtype());
98 ValidateHalfPlusTwo(saved_model_bundle, input_tensor_info.name(),
99 output_tensor_info.name());
100 }
101
102 // Helper function to validate that the SignatureDef found in the MetaGraphDef
103 // with the provided key has the expected string representation.
ValidateSignatureDef(const MetaGraphDef & meta_graph_def,const string & key,const string & expected_string_signature_def)104 void ValidateSignatureDef(const MetaGraphDef& meta_graph_def, const string& key,
105 const string& expected_string_signature_def) {
106 tensorflow::SignatureDef expected_signature;
107 CHECK(protobuf::TextFormat::ParseFromString(expected_string_signature_def,
108 &expected_signature));
109 auto iter = meta_graph_def.signature_def().find(key);
110 ASSERT_TRUE(iter != meta_graph_def.signature_def().end());
111 EXPECT_EQ(expected_signature.DebugString(), iter->second.DebugString());
112 }
113
114 // Checks that the input map in a signature def is populated correctly.
TEST(BundleShimTest,AddInputToSignatureDef)115 TEST(BundleShimTest, AddInputToSignatureDef) {
116 SignatureDef signature_def;
117 const string tensor_name = "foo_tensor";
118 const string map_key = "foo_key";
119
120 // Build a map of tensor-name to dtype, for the unit-test.
121 std::unordered_map<string, DataType> tensor_name_to_dtype;
122 tensor_name_to_dtype[tensor_name] = tensorflow::DT_STRING;
123
124 AddInputToSignatureDef(tensor_name, tensor_name_to_dtype, map_key,
125 &signature_def);
126 EXPECT_EQ(1, signature_def.inputs_size());
127 EXPECT_EQ(tensor_name, signature_def.inputs().find(map_key)->second.name());
128 }
129
130 // Checks that the output map in a signature def is populated correctly.
TEST(BundleShimTest,AddOutputToSignatureDef)131 TEST(BundleShimTest, AddOutputToSignatureDef) {
132 SignatureDef signature_def;
133 const string tensor_name = "foo_tensor";
134 const string map_key = "foo_key";
135
136 // Build a map of tensor-name to dtype, for the unit-test.
137 std::unordered_map<string, DataType> tensor_name_to_dtype;
138 tensor_name_to_dtype[tensor_name] = tensorflow::DT_STRING;
139
140 AddOutputToSignatureDef(tensor_name, tensor_name_to_dtype, map_key,
141 &signature_def);
142 EXPECT_EQ(1, signature_def.outputs_size());
143 EXPECT_EQ(tensor_name, signature_def.outputs().find(map_key)->second.name());
144 }
145
146 // Checks that no signature defs are added if the default signature is missing.
TEST(BundleShimTest,DefaultSignatureMissing)147 TEST(BundleShimTest, DefaultSignatureMissing) {
148 MetaGraphDef meta_graph_def;
149 // Signatures signatures;
150 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
151 EXPECT_EQ(0, meta_graph_def.signature_def_size());
152 }
153
154 // Checks that no signature defs are added if the default signature is empty.
TEST(BundleShimTest,DefaultSignatureEmpty)155 TEST(BundleShimTest, DefaultSignatureEmpty) {
156 Signatures signatures;
157 signatures.mutable_default_signature();
158
159 MetaGraphDef meta_graph_def;
160 (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
161 .mutable_any_list()
162 ->add_value()
163 ->PackFrom(signatures);
164 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
165 EXPECT_EQ(0, meta_graph_def.signature_def_size());
166 }
167
168 // Checks the conversion to signature def for a regression default signature.
TEST(BundleShimTest,DefaultSignatureRegression)169 TEST(BundleShimTest, DefaultSignatureRegression) {
170 Signatures signatures;
171 RegressionSignature* regression_signature =
172 signatures.mutable_default_signature()->mutable_regression_signature();
173 regression_signature->mutable_input()->set_tensor_name("foo-input");
174 regression_signature->mutable_output()->set_tensor_name("foo-output");
175 MetaGraphDef meta_graph_def;
176 (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
177 .mutable_any_list()
178 ->add_value()
179 ->PackFrom(signatures);
180 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
181 EXPECT_EQ(1, meta_graph_def.signature_def_size());
182 const auto actual_signature_def =
183 meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
184 EXPECT_EQ("foo-input", actual_signature_def->second.inputs()
185 .find(kRegressInputs)
186 ->second.name());
187 EXPECT_EQ("foo-output", actual_signature_def->second.outputs()
188 .find(kRegressOutputs)
189 ->second.name());
190 EXPECT_EQ(kRegressMethodName, actual_signature_def->second.method_name());
191 }
192
193 // Checks the conversion to signature def for a classification default
194 // signature.
TEST(BundleShimTest,DefaultSignatureClassification)195 TEST(BundleShimTest, DefaultSignatureClassification) {
196 Signatures signatures;
197 ClassificationSignature* classification_signature =
198 signatures.mutable_default_signature()
199 ->mutable_classification_signature();
200 classification_signature->mutable_input()->set_tensor_name("foo-input");
201 classification_signature->mutable_classes()->set_tensor_name("foo-classes");
202 classification_signature->mutable_scores()->set_tensor_name("foo-scores");
203 MetaGraphDef meta_graph_def;
204 (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
205 .mutable_any_list()
206 ->add_value()
207 ->PackFrom(signatures);
208 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
209 EXPECT_EQ(1, meta_graph_def.signature_def_size());
210 const auto actual_signature_def =
211 meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
212 EXPECT_EQ("foo-input", actual_signature_def->second.inputs()
213 .find(kClassifyInputs)
214 ->second.name());
215 EXPECT_EQ("foo-classes", actual_signature_def->second.outputs()
216 .find(kClassifyOutputClasses)
217 ->second.name());
218 EXPECT_EQ("foo-scores", actual_signature_def->second.outputs()
219 .find(kClassifyOutputScores)
220 ->second.name());
221 EXPECT_EQ(kClassifyMethodName, actual_signature_def->second.method_name());
222 }
223
224 // Checks that generic default signatures are not up converted.
TEST(BundleShimTest,DefaultSignatureGeneric)225 TEST(BundleShimTest, DefaultSignatureGeneric) {
226 TensorBinding input_binding;
227 input_binding.set_tensor_name("foo-input");
228
229 TensorBinding output_binding;
230 output_binding.set_tensor_name("foo-output");
231
232 Signatures signatures;
233 GenericSignature* generic_signature =
234 signatures.mutable_default_signature()->mutable_generic_signature();
235 generic_signature->mutable_map()->insert({kPredictInputs, input_binding});
236 generic_signature->mutable_map()->insert({kPredictOutputs, output_binding});
237
238 MetaGraphDef meta_graph_def;
239 (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
240 .mutable_any_list()
241 ->add_value()
242 ->PackFrom(signatures);
243 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
244 EXPECT_EQ(0, meta_graph_def.signature_def_size());
245 }
246
// Checks that each named regression signature is up-converted to its own
// SignatureDef (keyed by the signature name) with "inputs"/"outputs" entries
// and the regress method name.
TEST(BundleShimTest, NamedRegressionSignatures) {
  Signatures signatures;
  // Adds a regression signature `name` whose tensors are `<name>-input` and
  // `<name>-output`.
  auto add_named_regression = [&signatures](const string& name) {
    RegressionSignature* regression =
        (*signatures.mutable_named_signatures())[name]
            .mutable_regression_signature();
    regression->mutable_input()->set_tensor_name(name + "-input");
    regression->mutable_output()->set_tensor_name(name + "-output");
  };
  add_named_regression("foo");
  add_named_regression("bar");

  MetaGraphDef meta_graph_def;
  auto* packed_signatures =
      (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
          .mutable_any_list()
          ->add_value();
  packed_signatures->PackFrom(signatures);
  TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
  ASSERT_EQ(2, meta_graph_def.signature_def_size());

  ValidateSignatureDef(meta_graph_def, "foo",
                       "inputs { "
                       "  key: \"inputs\" "
                       "  value { "
                       "name: \"foo-input\" "
                       "  } "
                       "} "
                       "outputs { "
                       "  key: \"outputs\" "
                       "  value { "
                       "    name: \"foo-output\" "
                       "  } "
                       "} "
                       "method_name: \"tensorflow/serving/regress\" ");
  ValidateSignatureDef(meta_graph_def, "bar",
                       "inputs { "
                       "  key: \"inputs\" "
                       "  value { "
                       "name: \"bar-input\" "
                       "  } "
                       "} "
                       "outputs { "
                       "  key: \"outputs\" "
                       "  value { "
                       "    name: \"bar-output\" "
                       "  } "
                       "} "
                       "method_name: \"tensorflow/serving/regress\" ");
}
299
// Checks that each named classification signature is up-converted to its own
// SignatureDef and maps only the output tensors that were set: `foo` has
// classes but no scores, `bar` has scores but no classes.
TEST(BundleShimTest, NamedClassificationSignatures) {
  Signatures signatures;
  auto& named_signatures = *signatures.mutable_named_signatures();

  ClassificationSignature* foo_signature =
      named_signatures["foo"].mutable_classification_signature();
  foo_signature->mutable_input()->set_tensor_name("foo-input");
  foo_signature->mutable_classes()->set_tensor_name("foo-classes");

  ClassificationSignature* bar_signature =
      named_signatures["bar"].mutable_classification_signature();
  bar_signature->mutable_input()->set_tensor_name("bar-input");
  bar_signature->mutable_scores()->set_tensor_name("bar-scores");

  MetaGraphDef meta_graph_def;
  auto* packed_signatures =
      (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
          .mutable_any_list()
          ->add_value();
  packed_signatures->PackFrom(signatures);
  TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
  ASSERT_EQ(2, meta_graph_def.signature_def_size());

  ValidateSignatureDef(meta_graph_def, "foo",
                       "inputs { "
                       "  key: \"inputs\" "
                       "  value { "
                       "name: \"foo-input\" "
                       "  } "
                       "} "
                       "outputs { "
                       "  key: \"classes\" "
                       "  value { "
                       "    name: \"foo-classes\" "
                       "  } "
                       "} "
                       "method_name: \"tensorflow/serving/classify\" ");
  ValidateSignatureDef(meta_graph_def, "bar",
                       "inputs { "
                       "  key: \"inputs\" "
                       "  value { "
                       "name: \"bar-input\" "
                       "  } "
                       "} "
                       "outputs { "
                       "  key: \"scores\" "
                       "  value { "
                       "    name: \"bar-scores\" "
                       "  } "
                       "} "
                       "method_name: \"tensorflow/serving/classify\" ");
}
353
354 // Checks the Predict SignatureDef created when the named signatures have
355 // `inputs` and `outputs`.
TEST(BundleShimTest,NamedSignatureGenericInputsAndOutputs)356 TEST(BundleShimTest, NamedSignatureGenericInputsAndOutputs) {
357 TensorBinding input_binding;
358 input_binding.set_tensor_name("foo-input");
359
360 TensorBinding output_binding;
361 output_binding.set_tensor_name("foo-output");
362
363 Signatures signatures;
364 GenericSignature* input_generic_signature =
365 (*signatures.mutable_named_signatures())[kPredictInputs]
366 .mutable_generic_signature();
367 input_generic_signature->mutable_map()->insert({"foo-input", input_binding});
368
369 GenericSignature* output_generic_signature =
370 (*signatures.mutable_named_signatures())[kPredictOutputs]
371 .mutable_generic_signature();
372 output_generic_signature->mutable_map()->insert(
373 {"foo-output", output_binding});
374
375 MetaGraphDef meta_graph_def;
376 (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
377 .mutable_any_list()
378 ->add_value()
379 ->PackFrom(signatures);
380 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
381 EXPECT_EQ(1, meta_graph_def.signature_def_size());
382 const auto actual_signature_def =
383 meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
384 ASSERT_FALSE(actual_signature_def == meta_graph_def.signature_def().end());
385 ASSERT_FALSE(actual_signature_def->second.inputs().find("foo-input") ==
386 actual_signature_def->second.inputs().end());
387 EXPECT_EQ(
388 "foo-input",
389 actual_signature_def->second.inputs().find("foo-input")->second.name());
390 ASSERT_FALSE(actual_signature_def->second.outputs().find("foo-output") ==
391 actual_signature_def->second.outputs().end());
392 EXPECT_EQ(
393 "foo-output",
394 actual_signature_def->second.outputs().find("foo-output")->second.name());
395 EXPECT_EQ(kPredictMethodName, actual_signature_def->second.method_name());
396 }
397
398 // Checks that a signature def is not added if the named signatures is generic
399 // but does not have `inputs` and `outputs`.
TEST(BundleShimTest,NamedSignatureGenericNoInputsOrOutputs)400 TEST(BundleShimTest, NamedSignatureGenericNoInputsOrOutputs) {
401 TensorBinding input_binding;
402 input_binding.set_tensor_name("foo-input");
403
404 TensorBinding output_binding;
405 output_binding.set_tensor_name("foo-output");
406
407 Signatures signatures;
408 GenericSignature* generic_signature =
409 (*signatures.mutable_named_signatures())["unknown"]
410 .mutable_generic_signature();
411 generic_signature->mutable_map()->insert({kPredictInputs, input_binding});
412 generic_signature->mutable_map()->insert({kPredictOutputs, output_binding});
413
414 MetaGraphDef meta_graph_def;
415 (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
416 .mutable_any_list()
417 ->add_value()
418 ->PackFrom(signatures);
419 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
420 EXPECT_EQ(0, meta_graph_def.signature_def_size());
421 }
422
423 // Checks that a signature def is not added when the named signatures have only
424 // one of `inputs` and `outputs`.
TEST(BundleShimTest,NamedSignatureGenericOnlyInput)425 TEST(BundleShimTest, NamedSignatureGenericOnlyInput) {
426 TensorBinding input_binding;
427 input_binding.set_tensor_name("foo-input");
428
429 Signatures signatures;
430 GenericSignature* input_generic_signature =
431 (*signatures.mutable_named_signatures())[kPredictInputs]
432 .mutable_generic_signature();
433 input_generic_signature->mutable_map()->insert({"foo-input", input_binding});
434
435 MetaGraphDef meta_graph_def;
436 (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
437 .mutable_any_list()
438 ->add_value()
439 ->PackFrom(signatures);
440 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
441 EXPECT_EQ(0, meta_graph_def.signature_def_size());
442 }
443
444 // Tests up-conversion of Signatures to SignatureDefs when both `default` and
445 // `named` signatures are present.
TEST(BundleShimTest,DefaultAndNamedSignatureWithPredict)446 TEST(BundleShimTest, DefaultAndNamedSignatureWithPredict) {
447 Signatures signatures;
448
449 // Build a generic signature corresponding to `inputs` and add it to the
450 // Signatures to up-convert.
451 TensorBinding input_binding;
452 input_binding.set_tensor_name("foo-input");
453 GenericSignature* input_generic_signature =
454 (*signatures.mutable_named_signatures())[kPredictInputs]
455 .mutable_generic_signature();
456 input_generic_signature->mutable_map()->insert({"foo-input", input_binding});
457
458 // Build a generic signature corresponding to `outputs` and add it to the
459 // Signatures to up-convert.
460 TensorBinding output_binding;
461 output_binding.set_tensor_name("foo-output");
462 GenericSignature* output_generic_signature =
463 (*signatures.mutable_named_signatures())[kPredictOutputs]
464 .mutable_generic_signature();
465 output_generic_signature->mutable_map()->insert(
466 {"foo-output", output_binding});
467
468 // Build a regression signature and set it as the default signature.
469 RegressionSignature* inputs_regression_signature =
470 (*signatures.mutable_default_signature()).mutable_regression_signature();
471 inputs_regression_signature->mutable_input()->set_tensor_name("bar-input");
472
473 // Up-convert the available signatures to SignatureDefs.
474 MetaGraphDef meta_graph_def;
475 (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
476 .mutable_any_list()
477 ->add_value()
478 ->PackFrom(signatures);
479 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
480 EXPECT_EQ(2, meta_graph_def.signature_def_size());
481
482 // Verify that the default regression signature is converted to a
483 // SignatureDef that corresponds to the kDefaultServingSignatureDefKey.
484 const auto actual_signature_def_regress =
485 meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
486 ASSERT_FALSE(actual_signature_def_regress ==
487 meta_graph_def.signature_def().end());
488 ASSERT_FALSE(
489 actual_signature_def_regress->second.inputs().find(kRegressInputs) ==
490 actual_signature_def_regress->second.inputs().end());
491
492 // Verify that the `Predict` SignatureDef is created under a different key.
493 const auto actual_signature_def_predict = meta_graph_def.signature_def().find(
494 strings::StrCat(kDefaultServingSignatureDefKey, "_from_named"));
495 ASSERT_FALSE(actual_signature_def_predict ==
496 meta_graph_def.signature_def().end());
497 ASSERT_FALSE(
498 actual_signature_def_predict->second.inputs().find("foo-input") ==
499 actual_signature_def_predict->second.inputs().end());
500 EXPECT_EQ("foo-input", actual_signature_def_predict->second.inputs()
501 .find("foo-input")
502 ->second.name());
503 ASSERT_FALSE(
504 actual_signature_def_predict->second.outputs().find("foo-output") ==
505 actual_signature_def_predict->second.outputs().end());
506 EXPECT_EQ("foo-output", actual_signature_def_predict->second.outputs()
507 .find("foo-output")
508 ->second.name());
509 EXPECT_EQ(kPredictMethodName,
510 actual_signature_def_predict->second.method_name());
511 }
512
513 // Checks a basic up conversion for half plus two for SessionBundle.
TEST(BundleShimTest,BasicExportSessionBundle)514 TEST(BundleShimTest, BasicExportSessionBundle) {
515 const std::unordered_set<string> tags = {"tag"};
516 const string session_bundle_export_dir =
517 test_util::TestSrcDirPath(kSessionBundlePath);
518 LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags,
519 kDefaultServingSignatureDefKey,
520 /*expect_session_bundle=*/true);
521
522 // Verify that the named signature is also present.
523 SessionOptions session_options;
524 RunOptions run_options;
525 SavedModelBundle saved_model_bundle;
526 TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(session_options, run_options,
527 session_bundle_export_dir,
528 tags, &saved_model_bundle));
529 const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
530 const auto& signature_def_map = meta_graph_def.signature_def();
531 bool found_named_signature = false;
532 for (const auto& entry : signature_def_map) {
533 const string& key = entry.first;
534 const SignatureDef& signature_def = entry.second;
535
536 // We're looking for the key that is *not* kDefaultServingSignatureDefKey.
537 if (key == kDefaultServingSignatureDefKey) {
538 continue;
539 }
540 found_named_signature = true;
541
542 EXPECT_EQ(1, signature_def.inputs_size());
543 const auto it_inputs_x = signature_def.inputs().find("x");
544 EXPECT_FALSE(it_inputs_x == signature_def.inputs().end());
545 // Ensure the TensorInfo has name and dtype populated.
546 const TensorInfo& tensor_info_x = it_inputs_x->second;
547 EXPECT_EQ("x:0", tensor_info_x.name());
548 EXPECT_EQ(DT_FLOAT, tensor_info_x.dtype());
549
550 EXPECT_EQ(1, signature_def.outputs_size());
551 const auto it_outputs_y = signature_def.outputs().find("y");
552 EXPECT_FALSE(it_outputs_y == signature_def.outputs().end());
553 // Ensure the TensorInfo has name and dtype populated.
554 const TensorInfo& tensor_info_y = it_outputs_y->second;
555 EXPECT_EQ("y:0", tensor_info_y.name());
556 EXPECT_EQ(DT_FLOAT, tensor_info_y.dtype());
557 }
558 EXPECT_TRUE(found_named_signature);
559 }
560
561 // Checks a basic load for half plus two for SavedModelBundle.
TEST(BundleShimTest,BasicExportSavedModel)562 TEST(BundleShimTest, BasicExportSavedModel) {
563 const string saved_model_bundle_export_dir =
564 io::JoinPath(testing::TensorFlowSrcRoot(), kSavedModelBundlePath);
565 LoadAndValidateSavedModelBundle(saved_model_bundle_export_dir,
566 {kSavedModelTagServe}, "regress_x_to_y",
567 /*expect_session_bundle=*/false);
568 }
569
570 // Checks a basic load fails with an invalid export path.
TEST(BundleShimTest,InvalidPath)571 TEST(BundleShimTest, InvalidPath) {
572 const string invalid_export_dir = testing::TensorFlowSrcRoot();
573 SessionOptions session_options;
574 RunOptions run_options;
575 SavedModelBundle saved_model_bundle;
576 Status status = LoadSessionBundleOrSavedModelBundle(
577 session_options, run_options, invalid_export_dir, {kSavedModelTagServe},
578 &saved_model_bundle);
579 EXPECT_EQ(error::Code::NOT_FOUND, status.code());
580 }
581
582 // Checks that if loading a session bundle fails, the error is propagated to
583 // LoadSessionBundleOrSavedModelBundle().
TEST(BundleShimTest,LoadSessionBundleError)584 TEST(BundleShimTest, LoadSessionBundleError) {
585 const string session_bundle_export_dir =
586 test_util::TestSrcDirPath(kSessionBundlePath);
587 SessionOptions session_options;
588 RunOptions run_options;
589 // Invalid threadpool index to use for session-run calls.
590 run_options.set_inter_op_thread_pool(100);
591 SavedModelBundle saved_model_bundle;
592 EXPECT_FALSE(LoadSessionBundleOrSavedModelBundle(session_options, run_options,
593 session_bundle_export_dir,
594 {"tag"}, &saved_model_bundle)
595 .ok());
596 }
597
598 } // namespace
599 } // namespace internal
600 } // namespace serving
601 } // namespace tensorflow
602