• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/mutable_op_resolver.h"
17 
18 #include <stddef.h>
19 
20 #include <gtest/gtest.h>
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/schema/schema_generated.h"
23 #include "tensorflow/lite/testing/util.h"
24 
25 namespace tflite {
26 namespace {
27 
// Dummy kernel callbacks. The tests tell registrations apart by comparing the
// function pointers the resolver hands back against these functions.
DummyInvoke(TfLiteContext * context,TfLiteNode * node)29 TfLiteStatus DummyInvoke(TfLiteContext* context, TfLiteNode* node) {
30   return kTfLiteOk;
31 }
32 
GetDummyRegistration()33 TfLiteRegistration* GetDummyRegistration() {
34   static TfLiteRegistration registration = {
35       .init = nullptr,
36       .free = nullptr,
37       .prepare = nullptr,
38       .invoke = DummyInvoke,
39   };
40   return &registration;
41 }
42 
Dummy2Invoke(TfLiteContext * context,TfLiteNode * node)43 TfLiteStatus Dummy2Invoke(TfLiteContext* context, TfLiteNode* node) {
44   return kTfLiteOk;
45 }
46 
Dummy2Prepare(TfLiteContext * context,TfLiteNode * node)47 TfLiteStatus Dummy2Prepare(TfLiteContext* context, TfLiteNode* node) {
48   return kTfLiteOk;
49 }
50 
Dummy2Init(TfLiteContext * context,const char * buffer,size_t length)51 void* Dummy2Init(TfLiteContext* context, const char* buffer, size_t length) {
52   return nullptr;
53 }
54 
Dummy2free(TfLiteContext * context,void * buffer)55 void Dummy2free(TfLiteContext* context, void* buffer) {}
56 
GetDummy2Registration()57 TfLiteRegistration* GetDummy2Registration() {
58   static TfLiteRegistration registration = {
59       .init = Dummy2Init,
60       .free = Dummy2free,
61       .prepare = Dummy2Prepare,
62       .invoke = Dummy2Invoke,
63   };
64   return &registration;
65 }
66 
// A builtin added without an explicit version is found at version 1, tagged
// with its builtin code. (The test name is kept verbatim, typo included, so
// existing --gtest_filter patterns keep matching.)
TEST(MutableOpResolverTest, FinOp) {
  MutableOpResolver resolver;
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());

  const TfLiteRegistration* const add_v1 =
      resolver.FindOp(BuiltinOperator_ADD, /*version=*/1);
  ASSERT_NE(add_v1, nullptr);
  EXPECT_EQ(add_v1->invoke, &DummyInvoke);
  EXPECT_EQ(add_v1->builtin_code, BuiltinOperator_ADD);
  EXPECT_EQ(add_v1->version, 1);
}
78 
// Looking up a builtin that was never registered yields nullptr.
TEST(MutableOpResolverTest, FindMissingOp) {
  MutableOpResolver resolver;
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());

  EXPECT_EQ(resolver.FindOp(BuiltinOperator_CONV_2D, /*version=*/1), nullptr);
}
87 
// Passing a single version to AddBuiltin makes the kernel resolvable at that
// version only: versions 1 and 3 stay unresolved while version 2 resolves.
TEST(MutableOpResolverTest, RegisterOpWithSingleVersion) {
  MutableOpResolver resolver;
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2);

  ASSERT_EQ(resolver.FindOp(BuiltinOperator_ADD, 1), nullptr);

  const TfLiteRegistration* const add_v2 =
      resolver.FindOp(BuiltinOperator_ADD, 2);
  ASSERT_NE(add_v2, nullptr);
  EXPECT_EQ(add_v2->invoke, &DummyInvoke);
  EXPECT_EQ(add_v2->version, 2);

  ASSERT_EQ(resolver.FindOp(BuiltinOperator_ADD, 3), nullptr);
}
106 
// Passing a (min, max) pair to AddBuiltin makes the kernel resolvable at both
// versions 2 and 3, each lookup reporting the version that was requested.
TEST(MutableOpResolverTest, RegisterOpWithMultipleVersions) {
  MutableOpResolver resolver;
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2, 3);

  const int kSupportedVersions[] = {2, 3};
  for (const int version : kSupportedVersions) {
    const TfLiteRegistration* const found =
        resolver.FindOp(BuiltinOperator_ADD, version);
    ASSERT_NE(found, nullptr);
    EXPECT_EQ(found->invoke, &DummyInvoke);
    EXPECT_EQ(found->version, version);
  }
}
124 
// Versions just outside a registered [2, 3] range must not resolve.
TEST(MutableOpResolverTest, FindOpWithUnsupportedVersions) {
  MutableOpResolver resolver;
  resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2, 3);

  const int kOutOfRangeVersions[] = {1, 4};
  for (const int version : kOutOfRangeVersions) {
    EXPECT_EQ(resolver.FindOp(BuiltinOperator_ADD, version), nullptr);
  }
}
138 
// A custom op added by name without an explicit version is found at version 1
// and is tagged with BuiltinOperator_CUSTOM.
TEST(MutableOpResolverTest, FindCustomOp) {
  MutableOpResolver resolver;
  resolver.AddCustom("AWESOME", GetDummyRegistration());

  const TfLiteRegistration* const awesome = resolver.FindOp("AWESOME", 1);
  ASSERT_NE(awesome, nullptr);
  EXPECT_EQ(awesome->builtin_code, BuiltinOperator_CUSTOM);
  EXPECT_EQ(awesome->invoke, &DummyInvoke);
  EXPECT_EQ(awesome->version, 1);
}
149 
// The registration returned for a custom op carries the name it was
// registered under.
//
// Works on a local copy of the dummy registration so the static returned by
// GetDummyRegistration() (shared with every other test) is not mutated. The
// name is checked with EXPECT_STREQ: EXPECT_EQ on two `const char*` compares
// addresses, which only holds if the compiler pools identical string literals.
TEST(MutableOpResolverTest, FindCustomName) {
  MutableOpResolver resolver;
  TfLiteRegistration reg = *GetDummyRegistration();

  reg.custom_name = "UPDATED";
  resolver.AddCustom(reg.custom_name, &reg);
  const TfLiteRegistration* found_registration =
      resolver.FindOp(reg.custom_name, 1);

  ASSERT_NE(found_registration, nullptr);
  EXPECT_EQ(found_registration->builtin_code, BuiltinOperator_CUSTOM);
  EXPECT_EQ(found_registration->invoke, GetDummyRegistration()->invoke);
  EXPECT_EQ(found_registration->version, 1);
  EXPECT_STREQ(found_registration->custom_name, "UPDATED");
}
165 
// Checks that AddBuiltin keeps the kernel callbacks of the supplied
// registration but does not expose its custom_name for a builtin op.
//
// Uses a local copy of the second dummy registration so the shared static is
// left untouched. The op is looked up once, behind an ASSERT, so a missing
// registration fails the test instead of dereferencing nullptr.
TEST(MutableOpResolverTest, FindBuiltinName) {
  MutableOpResolver resolver1;
  const TfLiteRegistration* const dummy2 = GetDummy2Registration();
  TfLiteRegistration reg = *dummy2;

  reg.custom_name = "UPDATED";
  resolver1.AddBuiltin(BuiltinOperator_ADD, &reg);

  const TfLiteRegistration* const found =
      resolver1.FindOp(BuiltinOperator_ADD, 1);
  ASSERT_NE(found, nullptr);
  ASSERT_EQ(found->invoke, dummy2->invoke);
  ASSERT_EQ(found->prepare, dummy2->prepare);
  ASSERT_EQ(found->init, dummy2->init);
  ASSERT_EQ(found->free, dummy2->free);
  // custom_name for builtin ops will be nullptr
  EXPECT_EQ(found->custom_name, nullptr);
}
184 
// Looking up a custom name that was never registered yields nullptr.
TEST(MutableOpResolverTest, FindMissingCustomOp) {
  MutableOpResolver resolver;
  resolver.AddCustom("AWESOME", GetDummyRegistration());

  EXPECT_EQ(resolver.FindOp("EXCELLENT", /*version=*/1), nullptr);
}
193 
// A custom op added without an explicit version is not found at version 2.
TEST(MutableOpResolverTest, FindCustomOpWithUnsupportedVersion) {
  MutableOpResolver resolver;
  resolver.AddCustom("AWESOME", GetDummyRegistration());

  EXPECT_EQ(resolver.FindOp("AWESOME", /*version=*/2), nullptr);
}
201 
// AddAll merges another resolver into this one. Ops only the other resolver
// has are added, and for ops both have, the other resolver's registration wins.
//
// Every lookup is checked for nullptr before it is dereferenced, so a missing
// op fails the test cleanly rather than crashing the test binary.
TEST(MutableOpResolverTest, AddAll) {
  MutableOpResolver resolver1;
  resolver1.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());
  resolver1.AddBuiltin(BuiltinOperator_MUL, GetDummy2Registration());

  MutableOpResolver resolver2;
  resolver2.AddBuiltin(BuiltinOperator_SUB, GetDummyRegistration());
  resolver2.AddBuiltin(BuiltinOperator_ADD, GetDummy2Registration());

  // resolver2's ADD op should replace resolver1's ADD op, while augmenting
  // non-overlapping ops.
  resolver1.AddAll(resolver2);

  const TfLiteRegistration* const add = resolver1.FindOp(BuiltinOperator_ADD, 1);
  ASSERT_NE(add, nullptr);
  ASSERT_EQ(add->invoke, GetDummy2Registration()->invoke);

  const TfLiteRegistration* const mul = resolver1.FindOp(BuiltinOperator_MUL, 1);
  ASSERT_NE(mul, nullptr);
  ASSERT_EQ(mul->invoke, GetDummy2Registration()->invoke);

  const TfLiteRegistration* const sub = resolver1.FindOp(BuiltinOperator_SUB, 1);
  ASSERT_NE(sub, nullptr);
  ASSERT_EQ(sub->invoke, GetDummyRegistration()->invoke);
}
221 
222 }  // namespace
223 }  // namespace tflite
224 
main(int argc,char ** argv)225 int main(int argc, char** argv) {
226   ::tflite::LogToStderr();
227   ::testing::InitGoogleTest(&argc, argv);
228   return RUN_ALL_TESTS();
229 }
230