1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/signature/signature_def_util.h"
16
17 #include <gtest/gtest.h>
18 #include "tensorflow/cc/saved_model/signature_constants.h"
19 #include "tensorflow/core/platform/errors.h"
20 #include "tensorflow/core/platform/logging.h"
21 #include "tensorflow/lite/c/c_api.h"
22 #include "tensorflow/lite/c/common.h"
23 #include "tensorflow/lite/model_builder.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25 #include "tensorflow/lite/testing/util.h"
26
27 namespace tflite {
28 namespace {
29
30 using tensorflow::kClassifyMethodName;
31 using tensorflow::kDefaultServingSignatureDefKey;
32 using tensorflow::kPredictMethodName;
33 using tensorflow::SignatureDef;
34 using tensorflow::Status;
35
36 constexpr char kSignatureInput[] = "input";
37 constexpr char kSignatureOutput[] = "output";
38 constexpr char kTestFilePath[] = "tensorflow/lite/testdata/add.bin";
39
40 class SimpleSignatureDefUtilTest : public testing::Test {
41 protected:
SetUp()42 void SetUp() override {
43 flatbuffer_model_ = FlatBufferModel::BuildFromFile(kTestFilePath);
44 ASSERT_NE(flatbuffer_model_, nullptr);
45 model_ = flatbuffer_model_->GetModel();
46 ASSERT_NE(model_, nullptr);
47 }
48
GetTestSignatureDef()49 SignatureDef GetTestSignatureDef() {
50 auto signature_def = SignatureDef();
51 tensorflow::TensorInfo input_tensor;
52 tensorflow::TensorInfo output_tensor;
53 *input_tensor.mutable_name() = kSignatureInput;
54 *output_tensor.mutable_name() = kSignatureOutput;
55 *signature_def.mutable_method_name() = kClassifyMethodName;
56 (*signature_def.mutable_inputs())[kSignatureInput] = input_tensor;
57 (*signature_def.mutable_outputs())[kSignatureOutput] = output_tensor;
58 return signature_def;
59 }
60 std::unique_ptr<FlatBufferModel> flatbuffer_model_;
61 const Model* model_;
62 };
63
// Attaching a signature def map to the test model should make the entry
// visible to HasSignatureDef and round-trip it unchanged through
// GetSignatureDefMap.
TEST_F(SimpleSignatureDefUtilTest, SetSignatureDefTest) {
  const SignatureDef expected_signature_def = GetTestSignatureDef();
  const std::map<string, SignatureDef> expected_signature_def_map = {
      {kDefaultServingSignatureDefKey, expected_signature_def}};
  std::string model_output;
  // ASSERT (not EXPECT): if the call fails, model_output may not hold a valid
  // flatbuffer, and parsing it below would be undefined behavior.
  ASSERT_EQ(Status::OK(), SetSignatureDefMap(model_, expected_signature_def_map,
                                             &model_output));
  const Model* add_model = flatbuffers::GetRoot<Model>(model_output.data());
  EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey));

  std::map<string, SignatureDef> test_signature_def_map;
  ASSERT_EQ(Status::OK(),
            GetSignatureDefMap(add_model, &test_signature_def_map));
  // Use find() rather than operator[] so a missing key fails the test instead
  // of silently inserting a default-constructed SignatureDef.
  const auto it = test_signature_def_map.find(kDefaultServingSignatureDefKey);
  ASSERT_TRUE(it != test_signature_def_map.end());
  EXPECT_EQ(expected_signature_def.SerializeAsString(),
            it->second.SerializeAsString());
}
81
// Setting a new signature def map on a model that already has one should
// replace it: the old default-serving entry disappears and only the new
// entry (under a new key, with a different method name) is readable.
TEST_F(SimpleSignatureDefUtilTest, OverwriteSignatureDefTest) {
  auto expected_signature_def = GetTestSignatureDef();
  std::map<string, SignatureDef> expected_signature_def_map = {
      {kDefaultServingSignatureDefKey, expected_signature_def}};

  // First pass: attach the test signature under the default serving key.
  std::string model_output;
  ASSERT_EQ(Status::OK(), SetSignatureDefMap(model_, expected_signature_def_map,
                                             &model_output));
  const Model* add_model = flatbuffers::GetRoot<Model>(model_output.data());
  EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey));
  std::map<string, SignatureDef> test_signature_def_map;
  ASSERT_EQ(Status::OK(),
            GetSignatureDefMap(add_model, &test_signature_def_map));
  const auto added_it =
      test_signature_def_map.find(kDefaultServingSignatureDefKey);
  ASSERT_TRUE(added_it != test_signature_def_map.end());
  EXPECT_EQ(expected_signature_def.SerializeAsString(),
            added_it->second.SerializeAsString());

  // Second pass: replace the whole map with one different entry under a new
  // key.
  *expected_signature_def.mutable_method_name() = kPredictMethodName;
  expected_signature_def_map.erase(kDefaultServingSignatureDefKey);
  constexpr char kTestSignatureDefKey[] = "ServingTest";
  expected_signature_def_map[kTestSignatureDefKey] = expected_signature_def;
  // Write into a separate buffer. add_model points into model_output, so
  // reusing model_output as the destination would depend on
  // SetSignatureDefMap finishing its reads before it writes the output.
  std::string final_output;
  ASSERT_EQ(Status::OK(), SetSignatureDefMap(add_model,
                                             expected_signature_def_map,
                                             &final_output));
  const Model* final_model = flatbuffers::GetRoot<Model>(final_output.data());
  EXPECT_FALSE(HasSignatureDef(final_model, kDefaultServingSignatureDefKey));
  EXPECT_TRUE(HasSignatureDef(final_model, kTestSignatureDefKey));

  // Read into a fresh map so the checks reflect final_model only, not entries
  // left over from the first pass.
  std::map<string, SignatureDef> final_signature_def_map;
  ASSERT_EQ(Status::OK(),
            GetSignatureDefMap(final_model, &final_signature_def_map));
  EXPECT_EQ(final_signature_def_map.count(kDefaultServingSignatureDefKey), 0u);
  const auto final_it = final_signature_def_map.find(kTestSignatureDefKey);
  ASSERT_TRUE(final_it != final_signature_def_map.end());
  EXPECT_EQ(expected_signature_def.SerializeAsString(),
            final_it->second.SerializeAsString());
}
119
// The freshly loaded test model has no signature def under the default
// serving key. Reading its signature def map must still succeed.
TEST_F(SimpleSignatureDefUtilTest, GetSignatureDefTest) {
  std::map<string, SignatureDef> signature_defs;
  const Status read_status = GetSignatureDefMap(model_, &signature_defs);
  EXPECT_EQ(Status::OK(), read_status);

  const bool has_default_signature =
      HasSignatureDef(model_, kDefaultServingSignatureDefKey);
  EXPECT_FALSE(has_default_signature);
}
125
// Adding and then clearing a signature def map should remove the entry and
// leave the model with the same number of buffers it started with.
TEST_F(SimpleSignatureDefUtilTest, ClearSignatureDefTest) {
  // Keep the flatbuffers (unsigned) size type so the final comparison is not
  // a signed/unsigned mix.
  const auto expected_num_buffers = model_->buffers()->size();
  const SignatureDef expected_signature_def = GetTestSignatureDef();
  const std::map<string, SignatureDef> expected_signature_def_map = {
      {kDefaultServingSignatureDefKey, expected_signature_def}};
  std::string model_output;
  ASSERT_EQ(Status::OK(), SetSignatureDefMap(model_, expected_signature_def_map,
                                             &model_output));
  const Model* add_model = flatbuffers::GetRoot<Model>(model_output.data());
  EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey));

  std::map<string, SignatureDef> test_signature_def_map;
  ASSERT_EQ(Status::OK(),
            GetSignatureDefMap(add_model, &test_signature_def_map));
  const auto it = test_signature_def_map.find(kDefaultServingSignatureDefKey);
  ASSERT_TRUE(it != test_signature_def_map.end());
  EXPECT_EQ(expected_signature_def.SerializeAsString(),
            it->second.SerializeAsString());

  // add_model points into model_output, so write the cleared model into a
  // separate buffer instead of overwriting the one being read.
  std::string cleared_output;
  ASSERT_EQ(Status::OK(), ClearSignatureDefMap(add_model, &cleared_output));
  const Model* clear_model = flatbuffers::GetRoot<Model>(cleared_output.data());
  EXPECT_FALSE(HasSignatureDef(clear_model, kDefaultServingSignatureDefKey));
  EXPECT_EQ(expected_num_buffers, clear_model->buffers()->size());
}
148
// SetSignatureDefMap reports InvalidArgument both for an empty signature map
// and for a null output pointer.
TEST_F(SimpleSignatureDefUtilTest, SetSignatureDefErrorsTest) {
  std::string model_output;
  std::map<string, SignatureDef> signature_defs;

  // Nothing to attach: the empty map is rejected.
  const Status empty_map_status =
      SetSignatureDefMap(model_, signature_defs, &model_output);
  EXPECT_TRUE(tensorflow::errors::IsInvalidArgument(empty_map_status));

  // A non-empty map is still rejected when there is nowhere to write.
  signature_defs.emplace(kDefaultServingSignatureDefKey, SignatureDef());
  const Status null_output_status =
      SetSignatureDefMap(model_, signature_defs, nullptr);
  EXPECT_TRUE(tensorflow::errors::IsInvalidArgument(null_output_status));
}
159
160 } // namespace
161 } // namespace tflite
162