1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/flex/delegate.h"
16
17 #include <cstdint>
18 #include <vector>
19
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "tensorflow/lite/delegates/flex/test_util.h"
23 #include "tensorflow/lite/shared_library.h"
24
25 namespace tflite {
26 namespace flex {
27 namespace {
28
29 using ::testing::ElementsAre;
30
31 class DelegateTest : public testing::FlexModelTest {
32 public:
DelegateTest()33 DelegateTest() : delegate_(FlexDelegate::Create()) {
34 interpreter_.reset(new Interpreter(&error_reporter_));
35 }
36
~DelegateTest()37 ~DelegateTest() override {
38 // The delegate needs to be destructed after the interpreter because the
39 // interpreter references data contained in the delegate.
40 interpreter_.reset();
41 delegate_.reset();
42 }
43
ConfigureDelegate()44 void ConfigureDelegate() {
45 ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
46 kTfLiteOk);
47 }
48
49 private:
50 std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)> delegate_;
51 };
52
// Every node in this graph is a TF op (Unpack, Add, Mul), so the Flex
// delegate can presumably take over the whole graph.
TEST_F(DelegateTest, FullGraph) {
  // Nine float32 tensors: 0 and 3 are graph inputs, 8 is the graph output.
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});

  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  ConfigureDelegate();

  // Both inputs get the same [2, 2, 1] shape and the same data.
  const std::vector<float> input_data = {1.1f, 2.2f, 3.3f, 4.4f};
  for (int input : {0, 3}) {
    SetShape(input, {2, 2, 1});
    SetValues(input, input_data);
  }

  ASSERT_TRUE(Invoke());

  // Consistent with Unpack producing [2, 1] halves:
  // (2.2, 4.4) * (6.6, 8.8) = (14.52, 38.72).
  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
}
78
// An int32 graph must run through the delegate and keep its int32 output type.
TEST_F(DelegateTest, NonFloatTypeInference) {
  // Three int32 tensors: 0 and 1 are inputs, 2 is the output.
  AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});
  AddTfOp(testing::kAdd, {0, 1}, {2});

  ConfigureDelegate();

  // Input i receives input_data[i]; both are reshaped to 2x2 first.
  const std::vector<std::vector<int>> input_data = {{1, 2, 3, 4},
                                                    {4, 3, 2, 1}};
  for (int input = 0; input < 2; ++input) {
    SetShape(input, {2, 2});
    SetTypedValues<int>(input, input_data[input]);
  }

  ASSERT_TRUE(Invoke());

  // The two inputs are mirror images, so every element-wise sum is 5.
  ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
  ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5));
  ASSERT_EQ(GetType(2), kTfLiteInt32);
}
97
// String tensors must flow through the delegate. Judging by the expected
// output, the TF Add op applied to strings concatenates them element-wise.
TEST_F(DelegateTest, StringInference) {
  // Three string tensors: 0 and 1 are inputs, 2 is the output.
  AddTensors(3, {0, 1}, {2}, kTfLiteString, {2});
  AddTfOp(testing::kAdd, {0, 1}, {2});

  ConfigureDelegate();

  // Two 2x2 inputs whose digits run in opposite orders.
  SetShape(0, {2, 2});
  SetStringValues(0, {"1", "2", "3", "4"});

  SetShape(1, {2, 2});
  SetStringValues(1, {"4", "3", "2", "1"});

  ASSERT_TRUE(Invoke());

  // "1" + "4" -> "14", "2" + "3" -> "23", and so on.
  ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
  ASSERT_THAT(GetStringValues(2), ElementsAre("14", "23", "32", "41"));
  ASSERT_EQ(GetType(2), kTfLiteString);
}
116
// Same computation as FullGraph, but the final multiply is a builtin TFLite
// op (AddTfLiteMulOp). Presumably only the TF ops are delegated.
TEST_F(DelegateTest, MixedGraph) {
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});

  // TF ops.
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  // Builtin TFLite op.
  AddTfLiteMulOp({6, 7}, {8});

  ConfigureDelegate();

  const std::vector<float> input_data = {1.1f, 2.2f, 3.3f, 4.4f};
  for (int input : {0, 3}) {
    SetShape(input, {2, 2, 1});
    SetValues(input, input_data);
  }

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
}
138
// A builtin TFLite Mul sits between two runs of TF ops, so the delegated
// portion is presumably split into two separate partitions.
TEST_F(DelegateTest, SplitGraph) {
  AddTensors(10, {0}, {9}, kTfLiteFloat32, {3});

  // First TF run: 0 -> (1, 2) -> 3 -> (4, 5).
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kAdd, {1, 2}, {3});
  AddTfOp(testing::kUnpack, {3}, {4, 5});

  // TFLite op in the middle: (4, 5) -> 6.
  AddTfLiteMulOp({4, 5}, {6});

  // Second TF run: 6 -> (7, 8) -> 9.
  AddTfOp(testing::kUnpack, {6}, {7, 8});
  AddTfOp(testing::kAdd, {7, 8}, {9});

  ConfigureDelegate();

  SetShape(0, {2, 2, 2, 1});
  SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f});

  ASSERT_TRUE(Invoke());

  // Trace: 3 = (3, 2, 2, 2); 6 = (3, 2) * (2, 2) = (6, 4); 9 = 6 + 4.
  ASSERT_THAT(GetShape(9), ElementsAre(1));
  ASSERT_THAT(GetValues(9), ElementsAre(10.0f));
}
161
// A model made of a single builtin TFLite op. Applying the Flex delegate must
// leave it runnable and correct.
TEST_F(DelegateTest, OnlyTFLite) {
  // Ten tensors are declared; only 0, 1 (inputs) and 2 (output) are used.
  AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3});
  AddTfLiteMulOp({0, 1}, {2});

  ConfigureDelegate();

  const std::vector<std::vector<float>> input_data = {
      {1.1f, 2.2f, 3.3f, 4.4f}, {1.0f, 2.0f, 3.0f, 4.0f}};
  for (int input = 0; input < 2; ++input) {
    SetShape(input, {2, 2, 1});
    SetValues(input, input_data[input]);
  }

  ASSERT_TRUE(Invoke());

  // Element-wise product of the two inputs.
  ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
  ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));
}
179
// Calls Invoke() twice on the same delegated model, changing the input values
// between calls, and checks that each call produces the matching element-wise
// product.
TEST_F(DelegateTest, MultipleInvokeCalls) {
  AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3});
  AddTfLiteMulOp({0, 1}, {2});

  ConfigureDelegate();

  // First round.
  SetShape(0, {2, 2, 1});
  SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
  SetShape(1, {2, 2, 1});
  SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f});

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
  ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));

  // Second round: each tensor's shape is (re)applied immediately before its
  // values are written, the same pattern as the first round. Writing a
  // tensor's values before SetShape() on that tensor would depend on the
  // resize/allocation step leaving already-written data intact.
  SetShape(0, {2, 2, 1});
  SetValues(0, {4.4f, 3.3f, 2.2f, 1.1f});
  SetShape(1, {2, 2, 1});
  SetValues(1, {4.0f, 3.0f, 2.0f, 1.0f});

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
  ASSERT_THAT(GetValues(2), ElementsAre(17.6f, 9.9f, 4.4f, 1.1f));
}
207
// Shares one Flex delegate between two interpreters that hold different
// graphs, then checks that each interpreter still produces its own results.
TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
  // Interpreter A (the fixture's own): the all-TF Unpack/Add/Mul graph.
  {
    AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
    AddTfOp(testing::kUnpack, {0}, {1, 2});
    AddTfOp(testing::kUnpack, {3}, {4, 5});
    AddTfOp(testing::kAdd, {1, 4}, {6});
    AddTfOp(testing::kAdd, {2, 5}, {7});
    AddTfOp(testing::kMul, {6, 7}, {8});
    ConfigureDelegate();
    const std::vector<float> input_data = {1.1f, 2.2f, 3.3f, 4.4f};
    for (int input : {0, 3}) {
      SetShape(input, {2, 2, 1});
      SetValues(input, input_data);
    }
  }

  // Park interpreter A in `parked` and make a brand-new interpreter B current.
  // The fixture helpers act on `interpreter_`, so graph B is built in B and
  // configured with the *same* delegate.
  std::unique_ptr<Interpreter> parked(
      new Interpreter(&error_reporter_));
  interpreter_.swap(parked);
  {
    AddTensors(10, {0}, {9}, kTfLiteFloat32, {3});
    AddTfOp(testing::kUnpack, {0}, {1, 2});
    AddTfOp(testing::kAdd, {1, 2}, {3});
    AddTfOp(testing::kUnpack, {3}, {4, 5});
    AddTfLiteMulOp({4, 5}, {6});
    AddTfOp(testing::kUnpack, {6}, {7, 8});
    AddTfOp(testing::kAdd, {7, 8}, {9});
    ConfigureDelegate();
    SetShape(0, {2, 2, 2, 1});
    SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f});
  }

  // A becomes current again (B is parked); run and validate A.
  interpreter_.swap(parked);
  {
    ASSERT_TRUE(Invoke());
    EXPECT_THAT(GetShape(8), ElementsAre(2, 1));
    EXPECT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  }

  // B becomes current again (A is parked); run and validate B. `parked` is
  // destroyed at the end of this test body, i.e. before the fixture
  // destructor releases the shared delegate.
  interpreter_.swap(parked);
  {
    ASSERT_TRUE(Invoke());
    EXPECT_THAT(GetShape(9), ElementsAre(1));
    EXPECT_THAT(GetValues(9), ElementsAre(10.0f));
  }
}
257
// Runs the all-TF graph with the interpreter limited to one thread.
TEST_F(DelegateTest, SingleThreaded) {
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  // The thread count is set before the delegate is installed.
  interpreter_->SetNumThreads(1);
  ConfigureDelegate();

  const std::vector<float> input_data = {1.1f, 2.2f, 3.3f, 4.4f};
  for (int input : {0, 3}) {
    SetShape(input, {2, 2, 1});
    SetValues(input, input_data);
  }

  // Results must match the default-threading FullGraph test.
  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
}
282
// Runs the all-TF graph with the interpreter configured for four threads.
TEST_F(DelegateTest, MultiThreaded) {
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  // The thread count is set before the delegate is installed.
  interpreter_->SetNumThreads(4);
  ConfigureDelegate();

  const std::vector<float> input_data = {1.1f, 2.2f, 3.3f, 4.4f};
  for (int input : {0, 3}) {
    SetShape(input, {2, 2, 1});
    SetValues(input, input_data);
  }

  // Results must match the single-threaded run.
  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
}
307
308 #if !defined(__ANDROID__)
// Resolves the exported `TF_AcquireFlexDelegate` symbol at runtime and checks
// that calling it yields a non-null delegate.
TEST_F(DelegateTest, TF_AcquireFlexDelegate) {
  using AcquireFlexDelegateFn = Interpreter::TfLiteDelegatePtr (*)();
  const auto symbol = SharedLibrary::GetSymbol("TF_AcquireFlexDelegate");
  auto TF_AcquireFlexDelegate = reinterpret_cast<AcquireFlexDelegateFn>(symbol);
  ASSERT_TRUE(TF_AcquireFlexDelegate);

  // The returned smart pointer releases the delegate when it goes out of scope.
  auto delegate_ptr = TF_AcquireFlexDelegate();
  ASSERT_TRUE(delegate_ptr != nullptr);
}
317 #endif // !defined(__ANDROID__)
318
// All tensors are declared with shape [2], and the inputs are fed with that
// same shape. The output therefore should keep a static (non-dynamic)
// allocation.
TEST_F(DelegateTest, StaticOutput) {
  // Inputs 0..3, intermediates 4 and 5, output 6.
  AddTensors(7, {0, 1, 2, 3}, {6}, kTfLiteFloat32, {2});

  AddTfOp(testing::kAdd, {0, 2}, {4});
  AddTfOp(testing::kAdd, {1, 3}, {5});
  AddTfOp(testing::kMul, {4, 5}, {6});

  ConfigureDelegate();

  // Input shapes match the declared shape exactly.
  for (int input = 0; input < 4; ++input) {
    SetShape(input, {2});
  }
  const std::vector<float> low = {1.1f, 2.2f};
  const std::vector<float> high = {3.3f, 4.4f};
  SetValues(0, low);
  SetValues(1, high);
  SetValues(2, low);
  SetValues(3, high);

  ASSERT_TRUE(Invoke());

  // (2.2, 4.4) * (6.6, 8.8) = (14.52, 38.72).
  ASSERT_THAT(GetShape(6), ElementsAre(2));
  ASSERT_THAT(GetValues(6), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(6), kTfLiteFloat32);
  // Declared and actual shapes agree, so no dynamic tensor is needed.
  ASSERT_FALSE(IsDynamicTensor(6));
}
348
// RFFT followed by Imag through the Flex delegate.
//
// AddTensors declares the tensors as float32 with shape [3, 257]. Tensor 1 is
// then replaced by a constant int32 fft_length of [512], and input 0 is
// resized to [3, 512]. A 512-point real FFT produces 512 / 2 + 1 = 257 bins
// per row, so the output (tensor 3) matches its declared [3, 257] shape and
// should stay static.
TEST_F(DelegateTest, StaticOutputRFFT) {
  AddTensors(4, {0, 1}, {3}, kTfLiteFloat32, {3, 257});
  int32_t rfft_length[] = {512};
  SetConstTensor(1, {1}, kTfLiteInt32,
                 reinterpret_cast<const char*>(&rfft_length),
                 sizeof(rfft_length));

  AddTfOp(testing::kRfft, {0, 1}, {2});
  AddTfOp(testing::kImag, {2}, {3});

  // Apply the delegate.
  ConfigureDelegate();

  // Real-valued input: 3 rows of 512 samples.
  SetShape(0, {3, 512});
  SetValues(0, std::vector<float>(3 * 512, 1.0f));

  ASSERT_TRUE(Invoke());

  // Verify the shape the static-allocation claim below depends on.
  ASSERT_THAT(GetShape(3), ElementsAre(3, 257));
  ASSERT_EQ(GetType(3), kTfLiteFloat32);
  // Since shapes are consistent, static output tensor is used.
  ASSERT_FALSE(IsDynamicTensor(3));
}
373
// The tensors are declared with shape [3], but the inputs are fed as
// [2, 2, 1]. The real output shape ([2, 1]) then differs from the declared
// one, so the output must become a dynamic tensor.
TEST_F(DelegateTest, DynamicOutputAfterReshape) {
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});

  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  ConfigureDelegate();

  // Feed both inputs with a shape that differs from the declared [3].
  const std::vector<float> input_data = {1.1f, 2.2f, 3.3f, 4.4f};
  for (int input : {0, 3}) {
    SetShape(input, {2, 2, 1});
    SetValues(input, input_data);
  }

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
  // Declared [3] vs. actual [2, 1]: the output falls back to a dynamic tensor.
  ASSERT_TRUE(IsDynamicTensor(8));
}
401
402 } // namespace
403 } // namespace flex
404 } // namespace tflite
405