• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/signature/signature_def_util.h"
16 
17 #include <gtest/gtest.h>
18 #include "tensorflow/cc/saved_model/signature_constants.h"
19 #include "tensorflow/core/platform/errors.h"
20 #include "tensorflow/core/platform/logging.h"
21 #include "tensorflow/lite/c/c_api.h"
22 #include "tensorflow/lite/c/common.h"
23 #include "tensorflow/lite/model_builder.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25 #include "tensorflow/lite/testing/util.h"
26 
27 namespace tflite {
28 namespace {
29 
30 using tensorflow::kClassifyMethodName;
31 using tensorflow::kDefaultServingSignatureDefKey;
32 using tensorflow::kPredictMethodName;
33 using tensorflow::SignatureDef;
34 using tensorflow::Status;
35 
// Name and map key of the single input / output tensor in the SignatureDef
// that the fixture builds for these tests.
constexpr char kSignatureInput[]{"input"};
constexpr char kSignatureOutput[]{"output"};
// Path of the model file the fixture loads in SetUp().
constexpr char kTestFilePath[]{"tensorflow/lite/testdata/add.bin"};
39 
40 class SimpleSignatureDefUtilTest : public testing::Test {
41  protected:
SetUp()42   void SetUp() override {
43     flatbuffer_model_ = FlatBufferModel::BuildFromFile(kTestFilePath);
44     ASSERT_NE(flatbuffer_model_, nullptr);
45     model_ = flatbuffer_model_->GetModel();
46     ASSERT_NE(model_, nullptr);
47   }
48 
GetTestSignatureDef()49   SignatureDef GetTestSignatureDef() {
50     auto signature_def = SignatureDef();
51     tensorflow::TensorInfo input_tensor;
52     tensorflow::TensorInfo output_tensor;
53     *input_tensor.mutable_name() = kSignatureInput;
54     *output_tensor.mutable_name() = kSignatureOutput;
55     *signature_def.mutable_method_name() = kClassifyMethodName;
56     (*signature_def.mutable_inputs())[kSignatureInput] = input_tensor;
57     (*signature_def.mutable_outputs())[kSignatureOutput] = output_tensor;
58     return signature_def;
59   }
60   std::unique_ptr<FlatBufferModel> flatbuffer_model_;
61   const Model* model_;
62 };
63 
// Writes one SignatureDef under the default serving key into the test model
// and checks that it can be read back with an identical serialization.
TEST_F(SimpleSignatureDefUtilTest, SetSignatureDefTest) {
  const SignatureDef expected_signature_def = GetTestSignatureDef();
  const std::map<std::string, SignatureDef> expected_signature_def_map = {
      {kDefaultServingSignatureDefKey, expected_signature_def}};

  std::string model_output;
  // Fatal on failure: everything below parses the model out of model_output,
  // which must not be read if the write did not succeed.
  ASSERT_EQ(Status::OK(), SetSignatureDefMap(model_, expected_signature_def_map,
                                             &model_output));
  const Model* add_model = flatbuffers::GetRoot<Model>(model_output.data());
  EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey));

  std::map<std::string, SignatureDef> test_signature_def_map;
  EXPECT_EQ(Status::OK(),
            GetSignatureDefMap(add_model, &test_signature_def_map));
  const SignatureDef test_signature_def =
      test_signature_def_map[kDefaultServingSignatureDefKey];
  EXPECT_EQ(expected_signature_def.SerializeAsString(),
            test_signature_def.SerializeAsString());
}
81 
// Sets a SignatureDef under the default serving key, then sets a different map
// (new key, new method name) on the resulting model and checks that the new
// entry is present and the default-key entry is no longer reported.
TEST_F(SimpleSignatureDefUtilTest, OverwriteSignatureDefTest) {
  SignatureDef expected_signature_def = GetTestSignatureDef();
  std::map<std::string, SignatureDef> expected_signature_def_map = {
      {kDefaultServingSignatureDefKey, expected_signature_def}};

  // First pass: store the signature under the default serving key.
  std::string model_output;
  // Fatal on failure: the code below parses the model out of model_output.
  ASSERT_EQ(Status::OK(), SetSignatureDefMap(model_, expected_signature_def_map,
                                             &model_output));
  const Model* add_model = flatbuffers::GetRoot<Model>(model_output.data());
  EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey));
  std::map<std::string, SignatureDef> test_signature_def_map;
  EXPECT_EQ(Status::OK(),
            GetSignatureDefMap(add_model, &test_signature_def_map));
  SignatureDef test_signature_def =
      test_signature_def_map[kDefaultServingSignatureDefKey];
  EXPECT_EQ(expected_signature_def.SerializeAsString(),
            test_signature_def.SerializeAsString());

  // Second pass: replace the map with a single entry under a new key, using a
  // different method name so the stored value is distinguishable.
  expected_signature_def.set_method_name(kPredictMethodName);
  // Erase by key: a no-op if the key is absent, unlike erase(find(key)).
  expected_signature_def_map.erase(kDefaultServingSignatureDefKey);
  constexpr char kTestSignatureDefKey[] = "ServingTest";
  expected_signature_def_map[kTestSignatureDefKey] = expected_signature_def;
  // NOTE(review): add_model points into model_output, which is also the output
  // argument of this call; this presumably relies on SetSignatureDefMap
  // reading the model before writing the output - confirm. add_model is not
  // used after this point.
  ASSERT_EQ(
      Status::OK(),
      SetSignatureDefMap(add_model, expected_signature_def_map, &model_output));
  const Model* final_model = flatbuffers::GetRoot<Model>(model_output.data());
  EXPECT_FALSE(HasSignatureDef(final_model, kDefaultServingSignatureDefKey));
  EXPECT_EQ(Status::OK(),
            GetSignatureDefMap(final_model, &test_signature_def_map));
  // test_signature_def still holds the value read during the first pass, so
  // this compares the old stored signature against the updated expectation.
  EXPECT_NE(expected_signature_def.SerializeAsString(),
            test_signature_def.SerializeAsString());
  EXPECT_TRUE(HasSignatureDef(final_model, kTestSignatureDefKey));
  EXPECT_EQ(Status::OK(),
            GetSignatureDefMap(final_model, &test_signature_def_map));
  test_signature_def = test_signature_def_map[kTestSignatureDefKey];
  EXPECT_EQ(expected_signature_def.SerializeAsString(),
            test_signature_def.SerializeAsString());
}
119 
// Reading the signature defs of the unmodified test model must succeed, and
// that model is expected not to report the default serving signature.
TEST_F(SimpleSignatureDefUtilTest, GetSignatureDefTest) {
  std::map<std::string, SignatureDef> test_signature_def_map;
  EXPECT_EQ(Status::OK(), GetSignatureDefMap(model_, &test_signature_def_map));
  EXPECT_FALSE(HasSignatureDef(model_, kDefaultServingSignatureDefKey));
}
125 
// Adds a SignatureDef to the test model, clears it again, and checks that the
// cleared model no longer reports the signature and has the same number of
// buffers as the original model.
TEST_F(SimpleSignatureDefUtilTest, ClearSignatureDefTest) {
  // Keep the size in the type returned by size() so the final comparison does
  // not mix signed and unsigned operands.
  const auto expected_num_buffers = model_->buffers()->size();
  const SignatureDef expected_signature_def = GetTestSignatureDef();
  const std::map<std::string, SignatureDef> expected_signature_def_map = {
      {kDefaultServingSignatureDefKey, expected_signature_def}};

  std::string model_output;
  // Fatal on failure: the code below parses the model out of model_output.
  ASSERT_EQ(Status::OK(), SetSignatureDefMap(model_, expected_signature_def_map,
                                             &model_output));
  const Model* add_model = flatbuffers::GetRoot<Model>(model_output.data());
  EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey));

  std::map<std::string, SignatureDef> test_signature_def_map;
  EXPECT_EQ(Status::OK(),
            GetSignatureDefMap(add_model, &test_signature_def_map));
  const SignatureDef test_signature_def =
      test_signature_def_map[kDefaultServingSignatureDefKey];
  EXPECT_EQ(expected_signature_def.SerializeAsString(),
            test_signature_def.SerializeAsString());

  // NOTE(review): add_model points into model_output, which is also the output
  // argument of this call; presumably ClearSignatureDefMap reads the model
  // before writing the output - confirm. add_model is not used afterwards.
  ASSERT_EQ(Status::OK(), ClearSignatureDefMap(add_model, &model_output));
  const Model* clear_model = flatbuffers::GetRoot<Model>(model_output.data());
  EXPECT_FALSE(HasSignatureDef(clear_model, kDefaultServingSignatureDefKey));
  EXPECT_EQ(expected_num_buffers, clear_model->buffers()->size());
}
148 
// Checks the two argument errors SetSignatureDefMap is expected to report as
// InvalidArgument: an empty signature map and a null output pointer.
TEST_F(SimpleSignatureDefUtilTest, SetSignatureDefErrorsTest) {
  std::map<std::string, SignatureDef> test_signature_def_map;
  std::string model_output;

  // Case 1: empty map with a valid output string.
  EXPECT_TRUE(tensorflow::errors::IsInvalidArgument(
      SetSignatureDefMap(model_, test_signature_def_map, &model_output)));

  // Case 2: non-empty map (one default-constructed SignatureDef) with a null
  // output pointer.
  test_signature_def_map[kDefaultServingSignatureDefKey] = SignatureDef();
  EXPECT_TRUE(tensorflow::errors::IsInvalidArgument(
      SetSignatureDefMap(model_, test_signature_def_map, nullptr)));
}
159 
160 }  // namespace
161 }  // namespace tflite
162