• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/flex/delegate.h"

#include <cstdint>
#include <memory>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/flex/test_util.h"
#include "tensorflow/lite/shared_library.h"
24 
25 namespace tflite {
26 namespace flex {
27 namespace {
28 
29 using ::testing::ElementsAre;
30 
31 class DelegateTest : public testing::FlexModelTest {
32  public:
DelegateTest()33   DelegateTest() : delegate_(FlexDelegate::Create()) {
34     interpreter_.reset(new Interpreter(&error_reporter_));
35   }
36 
~DelegateTest()37   ~DelegateTest() override {
38     // The delegate needs to be destructed after the interpreter because the
39     // interpreter references data contained in the delegate.
40     interpreter_.reset();
41     delegate_.reset();
42   }
43 
ConfigureDelegate()44   void ConfigureDelegate() {
45     ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
46               kTfLiteOk);
47   }
48 
49  private:
50   std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)> delegate_;
51 };
52 
// Graph made exclusively of TF ops: both inputs are unpacked, the halves are
// added pairwise and the two sums are multiplied.
TEST_F(DelegateTest, FullGraph) {
  // Nine float tensors; 0 and 3 are graph inputs, 8 is the graph output.
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});

  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  ConfigureDelegate();

  // Both inputs receive the same shape and the same data.
  const std::vector<int> input_shape = {2, 2, 1};
  const std::vector<float> input_data = {1.1f, 2.2f, 3.3f, 4.4f};
  SetShape(0, input_shape);
  SetValues(0, input_data);
  SetShape(3, input_shape);
  SetValues(3, input_data);

  ASSERT_TRUE(Invoke());

  // 2.2 * 6.6 = 14.52 and 4.4 * 8.8 = 38.72.
  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
}
78 
// A TF Add over int32 tensors: the output must keep the int32 type.
TEST_F(DelegateTest, NonFloatTypeInference) {
  AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});
  AddTfOp(testing::kAdd, {0, 1}, {2});

  ConfigureDelegate();

  // Operands chosen so every element sums to 5.
  const std::vector<int> lhs = {1, 2, 3, 4};
  const std::vector<int> rhs = {4, 3, 2, 1};
  SetShape(0, {2, 2});
  SetTypedValues<int>(0, lhs);
  SetShape(1, {2, 2});
  SetTypedValues<int>(1, rhs);

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
  ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5));
  ASSERT_EQ(GetType(2), kTfLiteInt32);
}
97 
// A TF Add over string tensors; the expected output shows element-wise
// concatenation ("1" + "4" -> "14") and the string type must be preserved.
TEST_F(DelegateTest, StringInference) {
  AddTensors(3, {0, 1}, {2}, kTfLiteString, {2});
  AddTfOp(testing::kAdd, {0, 1}, {2});

  ConfigureDelegate();

  SetShape(0, {2, 2});
  SetStringValues(0, {"1", "2", "3", "4"});
  SetShape(1, {2, 2});
  SetStringValues(1, {"4", "3", "2", "1"});

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
  ASSERT_THAT(GetStringValues(2), ElementsAre("14", "23", "32", "41"));
  ASSERT_EQ(GetType(2), kTfLiteString);
}
116 
// Same computation as FullGraph, except the final multiply is a builtin
// TFLite op, so the graph mixes TF ops and TFLite ops.
TEST_F(DelegateTest, MixedGraph) {
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});

  // TF part: unpack both inputs, add halves pairwise.
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  // TFLite part: multiply the two sums.
  AddTfLiteMulOp({6, 7}, {8});

  ConfigureDelegate();

  const std::vector<int> input_shape = {2, 2, 1};
  const std::vector<float> input_data = {1.1f, 2.2f, 3.3f, 4.4f};
  SetShape(0, input_shape);
  SetValues(0, input_data);
  SetShape(3, input_shape);
  SetValues(3, input_data);

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
}
138 
// A TFLite MUL sits between two runs of TF ops, so the TF ops cannot all be
// grouped together and the graph is split around the TFLite op.
TEST_F(DelegateTest, SplitGraph) {
  AddTensors(10, {0}, {9}, kTfLiteFloat32, {3});

  // First TF segment: tensors 0 -> 1, 2 -> 3 -> 4, 5.
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kAdd, {1, 2}, {3});
  AddTfOp(testing::kUnpack, {3}, {4, 5});

  // TFLite op in the middle.
  AddTfLiteMulOp({4, 5}, {6});

  // Second TF segment: tensors 6 -> 7, 8 -> 9.
  AddTfOp(testing::kUnpack, {6}, {7, 8});
  AddTfOp(testing::kAdd, {7, 8}, {9});

  ConfigureDelegate();

  const std::vector<float> input_data = {3.0f, 1.0f, 0.5f, -1.0f,
                                         0.0f, 1.0f, 1.5f, 3.0f};
  SetShape(0, {2, 2, 2, 1});
  SetValues(0, input_data);

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(9), ElementsAre(1));
  ASSERT_THAT(GetValues(9), ElementsAre(10.0f));
}
161 
// A model containing a single builtin TFLite op and no TF ops at all must
// still run correctly once the Flex delegate has been applied.
TEST_F(DelegateTest, OnlyTFLite) {
  AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3});
  AddTfLiteMulOp({0, 1}, {2});

  ConfigureDelegate();

  const std::vector<int> input_shape = {2, 2, 1};
  SetShape(0, input_shape);
  SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
  SetShape(1, input_shape);
  SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f});

  ASSERT_TRUE(Invoke());

  // Element-wise products of the two inputs.
  ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
  ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));
}
179 
// Invoke() can be called repeatedly on one delegated graph. After the first
// run, the inputs are rewritten and the second run must reflect the new data.
TEST_F(DelegateTest, MultipleInvokeCalls) {
  // Single TFLite MUL: tensor 2 = tensor 0 * tensor 1.
  AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3});
  AddTfLiteMulOp({0, 1}, {2});

  ConfigureDelegate();

  SetShape(0, {2, 2, 1});
  SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
  SetShape(1, {2, 2, 1});
  SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f});

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
  ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));

  // Second round with reversed inputs. Each tensor gets its shape before its
  // values, matching the first round and every other test in this file.
  // SetShape() resizes the tensor (and presumably re-allocates; the helper is
  // not shown here), so writing values first and reshaping afterwards could
  // drop them.
  SetShape(0, {2, 2, 1});
  SetValues(0, {4.4f, 3.3f, 2.2f, 1.1f});
  SetShape(1, {2, 2, 1});
  SetValues(1, {4.0f, 3.0f, 2.0f, 1.0f});

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
  ASSERT_THAT(GetValues(2), ElementsAre(17.6f, 9.9f, 4.4f, 1.1f));
}
207 
// One delegate instance is applied to two separate interpreters with
// different graphs. The two are then invoked alternately, and each must
// produce its own expected result.
TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
  // Interpreter #1 (created by the fixture): the all-TF graph from FullGraph.
  {
    AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
    AddTfOp(testing::kUnpack, {0}, {1, 2});
    AddTfOp(testing::kUnpack, {3}, {4, 5});
    AddTfOp(testing::kAdd, {1, 4}, {6});
    AddTfOp(testing::kAdd, {2, 5}, {7});
    AddTfOp(testing::kMul, {6, 7}, {8});
    ConfigureDelegate();
    SetShape(0, {2, 2, 1});
    SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
    SetShape(3, {2, 2, 1});
    SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
  }

  // Park interpreter #1 in `parked` and make a brand-new interpreter #2 the
  // fixture's current one. Then build the SplitGraph graph on it, using the
  // *same* delegate.
  std::unique_ptr<Interpreter> parked(new Interpreter(&error_reporter_));
  interpreter_.swap(parked);
  {
    AddTensors(10, {0}, {9}, kTfLiteFloat32, {3});
    AddTfOp(testing::kUnpack, {0}, {1, 2});
    AddTfOp(testing::kAdd, {1, 2}, {3});
    AddTfOp(testing::kUnpack, {3}, {4, 5});
    AddTfLiteMulOp({4, 5}, {6});
    AddTfOp(testing::kUnpack, {6}, {7, 8});
    AddTfOp(testing::kAdd, {7, 8}, {9});
    ConfigureDelegate();
    SetShape(0, {2, 2, 2, 1});
    SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f});
  }

  // Each swap toggles which interpreter is current. Now #1 runs.
  interpreter_.swap(parked);
  {
    ASSERT_TRUE(Invoke());
    EXPECT_THAT(GetShape(8), ElementsAre(2, 1));
    EXPECT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  }

  // And now #2 runs.
  interpreter_.swap(parked);
  {
    ASSERT_TRUE(Invoke());
    EXPECT_THAT(GetShape(9), ElementsAre(1));
    EXPECT_THAT(GetValues(9), ElementsAre(10.0f));
  }
  // `parked` (interpreter #1) is destroyed when this body returns, before the
  // fixture destructor releases the delegate.
}
257 
// The FullGraph computation with the interpreter restricted to one thread
// before the delegate is installed.
TEST_F(DelegateTest, SingleThreaded) {
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  // The thread count must be set before ModifyGraphWithDelegate runs.
  interpreter_->SetNumThreads(1);
  ConfigureDelegate();

  const std::vector<int> input_shape = {2, 2, 1};
  const std::vector<float> input_data = {1.1f, 2.2f, 3.3f, 4.4f};
  SetShape(0, input_shape);
  SetValues(0, input_data);
  SetShape(3, input_shape);
  SetValues(3, input_data);

  // Results must be identical to the default-threading case.
  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
}
282 
// The FullGraph computation with the interpreter configured for four threads
// before the delegate is installed.
TEST_F(DelegateTest, MultiThreaded) {
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  // The thread count must be set before ModifyGraphWithDelegate runs.
  interpreter_->SetNumThreads(4);
  ConfigureDelegate();

  const std::vector<int> input_shape = {2, 2, 1};
  const std::vector<float> input_data = {1.1f, 2.2f, 3.3f, 4.4f};
  SetShape(0, input_shape);
  SetValues(0, input_data);
  SetShape(3, input_shape);
  SetValues(3, input_data);

  // Results must be identical to the single-threaded case.
  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
}
307 
308 #if !defined(__ANDROID__)
TEST_F(DelegateTest,TF_AcquireFlexDelegate)309 TEST_F(DelegateTest, TF_AcquireFlexDelegate) {
310   auto TF_AcquireFlexDelegate =
311       reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
312           SharedLibrary::GetSymbol("TF_AcquireFlexDelegate"));
313   ASSERT_TRUE(TF_AcquireFlexDelegate);
314   auto delegate_ptr = TF_AcquireFlexDelegate();
315   ASSERT_TRUE(delegate_ptr != nullptr);
316 }
317 #endif  // !defined(__ANDROID__)
318 
// All tensors are declared with shape [2] and the inputs keep that shape, so
// the output is expected to stay a static (non-dynamic) tensor.
TEST_F(DelegateTest, StaticOutput) {
  AddTensors(7, {0, 1, 2, 3}, {6}, kTfLiteFloat32, {2});

  // out = (in0 + in2) * (in1 + in3)
  AddTfOp(testing::kAdd, {0, 2}, {4});
  AddTfOp(testing::kAdd, {1, 3}, {5});
  AddTfOp(testing::kMul, {4, 5}, {6});

  ConfigureDelegate();

  // Inputs use the same shape the graph was declared with. All shapes are
  // set before any values are written.
  for (int input : {0, 1, 2, 3}) {
    SetShape(input, {2});
  }
  const std::vector<float> low = {1.1f, 2.2f};
  const std::vector<float> high = {3.3f, 4.4f};
  SetValues(0, low);
  SetValues(1, high);
  SetValues(2, low);
  SetValues(3, high);

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(6), ElementsAre(2));
  ASSERT_THAT(GetValues(6), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(6), kTfLiteFloat32);
  // Declared and actual shapes agree, so no dynamic output tensor is needed.
  ASSERT_FALSE(IsDynamicTensor(6));
}
348 
// RFFT followed by Imag. With a [3, 512] input and fft_length 512, the output
// has 512 / 2 + 1 = 257 bins per row, i.e. shape [3, 257]. That matches the
// shape declared for tensor 3, so the output should stay a static tensor.
TEST_F(DelegateTest, StaticOutputRFFT) {
  // Every tensor is declared as [3, 257]. Input 0 is resized to [3, 512]
  // below, and tensor 1 is overridden with a constant.
  AddTensors(4, {0, 1}, {3}, kTfLiteFloat32, {3, 257});
  // Tensor 1: constant int32 fft_length. The buffer outlives Invoke() because
  // it lives until the end of this test body.
  int32_t rfft_length[] = {512};
  SetConstTensor(1, {1}, kTfLiteInt32,
                 reinterpret_cast<const char*>(&rfft_length),
                 sizeof(rfft_length));

  AddTfOp(testing::kRfft, {0, 1}, {2});
  AddTfOp(testing::kImag, {2}, {3});

  ConfigureDelegate();

  SetShape(0, {3, 512});
  SetValues(0, std::vector<float>(3 * 512, 1.0f));

  ASSERT_TRUE(Invoke());

  // Verify the "shapes are consistent" premise explicitly instead of relying
  // only on the dynamic-tensor check below.
  ASSERT_THAT(GetShape(3), ElementsAre(3, 257));
  ASSERT_EQ(GetType(3), kTfLiteFloat32);
  // Since shapes are consistent, static output tensor is used.
  ASSERT_FALSE(IsDynamicTensor(3));
}
373 
// Tensors are declared with shape [3], but the inputs are reshaped to
// [2, 2, 1]. The real output shape ([2, 1]) then differs from the declared
// one, so the output is expected to be a dynamic tensor.
TEST_F(DelegateTest, DynamicOutputAfterReshape) {
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});

  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  ConfigureDelegate();

  // Reshape the inputs away from their declared [3] shape.
  const std::vector<int> reshaped = {2, 2, 1};
  const std::vector<float> input_data = {1.1f, 2.2f, 3.3f, 4.4f};
  SetShape(0, reshaped);
  SetValues(0, input_data);
  SetShape(3, reshaped);
  SetValues(3, input_data);

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
  // Declared ([3]) and actual ([2, 1]) shapes differ -> dynamic output.
  ASSERT_TRUE(IsDynamicTensor(8));
}
401 
402 }  // namespace
403 }  // namespace flex
404 }  // namespace tflite
405