1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/mutable_op_resolver.h"
17
18 #include <stddef.h>
19
20 #include <gtest/gtest.h>
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/schema/schema_generated.h"
23 #include "tensorflow/lite/testing/util.h"
24
25 namespace tflite {
26 namespace {
27
// Dummy kernel callbacks. The tests compare a found registration's function
// pointers against these to tell which registration the resolver returned.
DummyInvoke(TfLiteContext * context,TfLiteNode * node)29 TfLiteStatus DummyInvoke(TfLiteContext* context, TfLiteNode* node) {
30 return kTfLiteOk;
31 }
32
GetDummyRegistration()33 TfLiteRegistration* GetDummyRegistration() {
34 static TfLiteRegistration registration = {
35 .init = nullptr,
36 .free = nullptr,
37 .prepare = nullptr,
38 .invoke = DummyInvoke,
39 };
40 return ®istration;
41 }
42
Dummy2Invoke(TfLiteContext * context,TfLiteNode * node)43 TfLiteStatus Dummy2Invoke(TfLiteContext* context, TfLiteNode* node) {
44 return kTfLiteOk;
45 }
46
Dummy2Prepare(TfLiteContext * context,TfLiteNode * node)47 TfLiteStatus Dummy2Prepare(TfLiteContext* context, TfLiteNode* node) {
48 return kTfLiteOk;
49 }
50
Dummy2Init(TfLiteContext * context,const char * buffer,size_t length)51 void* Dummy2Init(TfLiteContext* context, const char* buffer, size_t length) {
52 return nullptr;
53 }
54
Dummy2free(TfLiteContext * context,void * buffer)55 void Dummy2free(TfLiteContext* context, void* buffer) {}
56
GetDummy2Registration()57 TfLiteRegistration* GetDummy2Registration() {
58 static TfLiteRegistration registration = {
59 .init = Dummy2Init,
60 .free = Dummy2free,
61 .prepare = Dummy2Prepare,
62 .invoke = Dummy2Invoke,
63 };
64 return ®istration;
65 }
66
// A builtin added with the default version can be found at version 1, and
// the returned registration is stamped with the builtin code and version.
// NOTE(review): "FinOp" is presumably a typo for "FindOp"; the name is kept
// unchanged so existing --gtest_filter patterns still match.
TEST(MutableOpResolverTest, FinOp) {
  MutableOpResolver resolver;
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());

  const TfLiteRegistration* const add_op =
      resolver.FindOp(BuiltinOperator_ADD, 1);
  ASSERT_NE(add_op, nullptr);

  EXPECT_TRUE(add_op->invoke == DummyInvoke);
  EXPECT_EQ(add_op->builtin_code, BuiltinOperator_ADD);
  EXPECT_EQ(add_op->version, 1);
}
78
// Looking up a builtin that was never registered yields nullptr.
TEST(MutableOpResolverTest, FindMissingOp) {
  MutableOpResolver resolver;
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());

  // Only ADD is registered, so CONV_2D must not resolve.
  EXPECT_EQ(resolver.FindOp(BuiltinOperator_CONV_2D, 1), nullptr);
}
87
// Registering a builtin for exactly one version makes only that version
// resolvable; its neighbours on either side are not found.
TEST(MutableOpResolverTest, RegisterOpWithSingleVersion) {
  MutableOpResolver resolver;
  // Register the kernel for version 2 only.
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2);

  // Below the registered version.
  ASSERT_EQ(resolver.FindOp(BuiltinOperator_ADD, 1), nullptr);

  const TfLiteRegistration* const v2 = resolver.FindOp(BuiltinOperator_ADD, 2);
  ASSERT_NE(v2, nullptr);
  EXPECT_TRUE(v2->invoke == DummyInvoke);
  EXPECT_EQ(v2->version, 2);

  // Above the registered version.
  ASSERT_EQ(resolver.FindOp(BuiltinOperator_ADD, 3), nullptr);
}
106
// Registering a builtin over a version range makes every version in that
// range resolvable, each stamped with its own version number.
TEST(MutableOpResolverTest, RegisterOpWithMultipleVersions) {
  MutableOpResolver resolver;
  // Register the kernel for versions 2 through 3.
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2, 3);

  for (const int version : {2, 3}) {
    const TfLiteRegistration* const op =
        resolver.FindOp(BuiltinOperator_ADD, version);
    ASSERT_NE(op, nullptr);
    EXPECT_TRUE(op->invoke == DummyInvoke);
    EXPECT_EQ(op->version, version);
  }
}
124
// Versions just outside a registered range do not resolve.
TEST(MutableOpResolverTest, FindOpWithUnsupportedVersions) {
  MutableOpResolver resolver;
  // Register the kernel for versions 2 through 3.
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2, 3);

  // One below and one above the registered range.
  for (const int version : {1, 4}) {
    EXPECT_EQ(resolver.FindOp(BuiltinOperator_ADD, version), nullptr);
  }
}
138
// A custom op added by name is found at the default version 1 and is
// reported with the CUSTOM builtin code.
TEST(MutableOpResolverTest, FindCustomOp) {
  MutableOpResolver resolver;
  resolver.AddCustom("AWESOME", GetDummyRegistration());

  const TfLiteRegistration* const custom_op = resolver.FindOp("AWESOME", 1);
  ASSERT_NE(custom_op, nullptr);

  EXPECT_EQ(custom_op->builtin_code, BuiltinOperator_CUSTOM);
  EXPECT_TRUE(custom_op->invoke == DummyInvoke);
  EXPECT_EQ(custom_op->version, 1);
}
149
// A custom op registered under a given name is found by that name, and the
// returned registration carries the same name in custom_name.
TEST(MutableOpResolverTest, FindCustomName) {
  MutableOpResolver resolver;
  // Work on a local copy so the shared static returned by
  // GetDummyRegistration() is not modified for the other tests.
  TfLiteRegistration reg = *GetDummyRegistration();

  reg.custom_name = "UPDATED";
  resolver.AddCustom(reg.custom_name, &reg);
  const TfLiteRegistration* found_registration =
      resolver.FindOp(reg.custom_name, 1);

  ASSERT_NE(found_registration, nullptr);
  EXPECT_EQ(found_registration->builtin_code, BuiltinOperator_CUSTOM);
  EXPECT_EQ(found_registration->invoke, GetDummyRegistration()->invoke);
  EXPECT_EQ(found_registration->version, 1);
  // Compare string contents; EXPECT_EQ on two const char* would only compare
  // pointers, and identical literals are not guaranteed to share an address.
  EXPECT_STREQ(found_registration->custom_name, "UPDATED");
}
165
// A builtin registration keeps all of its callbacks when added to the
// resolver, but any custom_name set on the source registration is not
// carried into the builtin entry.
TEST(MutableOpResolverTest, FindBuiltinName) {
  MutableOpResolver resolver1;
  // Work on a local copy so the shared static returned by
  // GetDummy2Registration() is not modified for the other tests.
  TfLiteRegistration reg = *GetDummy2Registration();

  reg.custom_name = "UPDATED";
  resolver1.AddBuiltin(BuiltinOperator_ADD, &reg);

  // Look up once and check for nullptr before dereferencing, so a missing op
  // is reported as a failure instead of crashing the test binary.
  const TfLiteRegistration* found_registration =
      resolver1.FindOp(BuiltinOperator_ADD, 1);
  ASSERT_NE(found_registration, nullptr);

  ASSERT_EQ(found_registration->invoke, GetDummy2Registration()->invoke);
  ASSERT_EQ(found_registration->prepare, GetDummy2Registration()->prepare);
  ASSERT_EQ(found_registration->init, GetDummy2Registration()->init);
  ASSERT_EQ(found_registration->free, GetDummy2Registration()->free);
  // custom_name for builtin ops is expected to be nullptr, even though the
  // registration passed to AddBuiltin set one.
  EXPECT_EQ(found_registration->custom_name, nullptr);
}
184
// Looking up a custom op under a name that was never registered yields
// nullptr.
TEST(MutableOpResolverTest, FindMissingCustomOp) {
  MutableOpResolver resolver;
  resolver.AddCustom("AWESOME", GetDummyRegistration());

  // Only "AWESOME" is registered.
  EXPECT_EQ(resolver.FindOp("EXCELLENT", 1), nullptr);
}
193
// A custom op added with the default version does not resolve at a higher
// version.
TEST(MutableOpResolverTest, FindCustomOpWithUnsupportedVersion) {
  MutableOpResolver resolver;
  resolver.AddCustom("AWESOME", GetDummyRegistration());

  // Only the default version (1) was registered.
  EXPECT_EQ(resolver.FindOp("AWESOME", 2), nullptr);
}
201
// AddAll merges another resolver into this one: ops present in both take the
// other resolver's registration, and ops unique to either side are kept.
TEST(MutableOpResolverTest, AddAll) {
  MutableOpResolver resolver1;
  resolver1.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());
  resolver1.AddBuiltin(BuiltinOperator_MUL, GetDummy2Registration());

  MutableOpResolver resolver2;
  resolver2.AddBuiltin(BuiltinOperator_SUB, GetDummyRegistration());
  resolver2.AddBuiltin(BuiltinOperator_ADD, GetDummy2Registration());

  // resolver2's ADD op should replace resolver1's ADD op, while the
  // non-overlapping ops (MUL from resolver1, SUB from resolver2) are kept.
  resolver1.AddAll(resolver2);

  // Each lookup is checked for nullptr before it is dereferenced, so a
  // missing op is reported as a test failure instead of crashing.
  const TfLiteRegistration* add = resolver1.FindOp(BuiltinOperator_ADD, 1);
  ASSERT_NE(add, nullptr);
  ASSERT_EQ(add->invoke, GetDummy2Registration()->invoke);

  const TfLiteRegistration* mul = resolver1.FindOp(BuiltinOperator_MUL, 1);
  ASSERT_NE(mul, nullptr);
  ASSERT_EQ(mul->invoke, GetDummy2Registration()->invoke);

  const TfLiteRegistration* sub = resolver1.FindOp(BuiltinOperator_SUB, 1);
  ASSERT_NE(sub, nullptr);
  ASSERT_EQ(sub->invoke, GetDummyRegistration()->invoke);
}
221
222 } // namespace
223 } // namespace tflite
224
main(int argc,char ** argv)225 int main(int argc, char** argv) {
226 ::tflite::LogToStderr();
227 ::testing::InitGoogleTest(&argc, argv);
228 return RUN_ALL_TESTS();
229 }
230