• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/interpreter.h"
17 #include <gmock/gmock.h>
18 #include <gtest/gtest.h>
19 #include "tensorflow/lite/core/api/error_reporter.h"
20 #include "tensorflow/lite/kernels/internal/compatibility.h"
21 #include "tensorflow/lite/kernels/kernel_util.h"
22 #include "tensorflow/lite/schema/schema_generated.h"
23 #include "tensorflow/lite/string_util.h"
24 #include "tensorflow/lite/testing/util.h"
25 
26 namespace tflite {
27 
28 // InterpreterTest is a friend of Interpreter, so it can access context_.
29 class InterpreterTest : public ::testing::Test {
30  public:
31   template <typename Delegate>
ModifyGraphWithDelegate(Interpreter * interpreter,std::unique_ptr<Delegate> delegate)32   static TfLiteStatus ModifyGraphWithDelegate(
33       Interpreter* interpreter, std::unique_ptr<Delegate> delegate) {
34     Interpreter::TfLiteDelegatePtr tflite_delegate(
35         delegate.release(), [](TfLiteDelegate* delegate) {
36           delete reinterpret_cast<Delegate*>(delegate);
37         });
38     return interpreter->ModifyGraphWithDelegate(std::move(tflite_delegate));
39   }
40 
41  protected:
GetInterpreterContext()42   TfLiteContext* GetInterpreterContext() { return interpreter_.context_; }
43 
44   Interpreter interpreter_;
45 };
46 
47 namespace ops {
48 namespace builtin {
49 TfLiteRegistration* Register_PADV2();
50 TfLiteRegistration* Register_NEG();
51 }  // namespace builtin
52 }  // namespace ops
53 namespace {
54 
55 using ::testing::IsEmpty;
56 
57 // Make an interpreter that has no tensors and no nodes
TEST(BasicInterpreter,ZeroInterpreter)58 TEST(BasicInterpreter, ZeroInterpreter) {
59   testing::internal::CaptureStderr();
60 
61   Interpreter interpreter;
62   EXPECT_THAT(testing::internal::GetCapturedStderr(),
63               testing::HasSubstr("INFO: Initialized TensorFlow Lite runtime"));
64 
65   interpreter.SetInputs({});
66   interpreter.SetOutputs({});
67   ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
68   ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
69 
70   // Creating a new interpreter should not redundantly log runtime init.
71   testing::internal::CaptureStderr();
72   Interpreter interpreter2;
73   EXPECT_THAT(testing::internal::GetCapturedStderr(), IsEmpty());
74 }
75 
76 // Test various error conditions.
TEST(BasicInterpreter,InvokeInvalidModel)77 TEST(BasicInterpreter, InvokeInvalidModel) {
78   Interpreter interpreter;
79   ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
80   ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
81   ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
82 }
83 
// AllocateTensors() must zero-initialize tensors registered as variables.
TEST(BasicInterpreter, TestAllocateTensorsResetVariableTensors) {
  Interpreter interpreter;
  int tensor_index = -1;
  ASSERT_EQ(interpreter.AddTensors(1, &tensor_index), kTfLiteOk);
  constexpr int kTensorSize = 16;
  // Value-initialize: a default-initialized POD would hand indeterminate
  // scale/zero_point values to the interpreter.
  const TfLiteQuantizationParams quant{};
  ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
                tensor_index, kTfLiteFloat32, "", {kTensorSize}, quant,
                /*is_variable=*/true),
            kTfLiteOk);
  ASSERT_EQ(interpreter.SetVariables({tensor_index}), kTfLiteOk);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

  TfLiteTensor* tensor = interpreter.tensor(tensor_index);
  ASSERT_NE(tensor->data.f, nullptr);
  // Ensure that variable tensors are reset to zero.
  for (int i = 0; i < kTensorSize; ++i) {
    ASSERT_EQ(tensor->data.f[i], 0.0f);
  }
}
100 
101 // Test size accessor functions.
TEST(BasicInterpreter,TestSizeFunctions)102 TEST(BasicInterpreter, TestSizeFunctions) {
103   Interpreter interpreter;
104   int base_index;
105   ASSERT_EQ(interpreter.nodes_size(), 0);
106   ASSERT_EQ(interpreter.tensors_size(), 0);
107   ASSERT_EQ(interpreter.AddTensors(2, &base_index), kTfLiteOk);
108   ASSERT_EQ(interpreter.tensors_size(), 2);
109   ASSERT_EQ(base_index, 0);
110   ASSERT_EQ(interpreter.AddTensors(3, &base_index), kTfLiteOk);
111   ASSERT_EQ(interpreter.tensors_size(), 5);
112   ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
113   ASSERT_EQ(interpreter.tensors_size(), 6);
114   ASSERT_EQ(base_index, 2);
115 }
116 
117 // Test if invalid indices make a model inconsistent (and conversely if
118 // valid indices keep a model consistent).
TEST(BasicInterpreter,InconsistentModel)119 TEST(BasicInterpreter, InconsistentModel) {
120   // Invalid inputs
121   {
122     Interpreter interpreter;
123     ASSERT_NE(interpreter.SetInputs({5}), kTfLiteOk);
124     ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk);
125     ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
126     ASSERT_EQ(interpreter.inputs(), std::vector<int>());
127   }
128   // Invalid outputs
129   {
130     Interpreter interpreter;
131     ASSERT_NE(interpreter.SetOutputs({5}), kTfLiteOk);
132     ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk);
133     ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
134     ASSERT_EQ(interpreter.outputs(), std::vector<int>());
135   }
136   // Invalid node inputs
137   {
138     Interpreter interpreter;
139     TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
140     ASSERT_NE(interpreter.AddNodeWithParameters({3}, {0}, nullptr, 0, nullptr,
141                                                 &registration),
142               kTfLiteOk);
143     ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk);
144     ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
145   }
146   // Valid inputs and outputs and a node with valid inputs and outputs
147   {
148     Interpreter interpreter;
149     ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
150     TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
151     ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
152     ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
153     ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
154                                                 &registration),
155               kTfLiteOk);
156   }
157 }
158 
159 // Make an interpreter that has one tensor but no ops
TEST(BasicInterpreter,CheckAllocate)160 TEST(BasicInterpreter, CheckAllocate) {
161   struct {
162     TfLiteType type;
163     size_t size;
164   } cases[] = {
165       {kTfLiteFloat32, sizeof(float)}, {kTfLiteInt32, sizeof(int32_t)},
166       {kTfLiteUInt8, sizeof(uint8_t)}, {kTfLiteInt64, sizeof(int64_t)},
167       {kTfLiteInt16, sizeof(int16_t)},
168   };
169 
170   for (auto test : cases) {
171     Interpreter interpreter;
172     ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
173     interpreter.SetInputs({0, 1});
174     interpreter.SetOutputs({});
175     TfLiteQuantizationParams quant;
176 
177     interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant);
178     interpreter.SetTensorParametersReadWrite(1, test.type, "", {4}, quant);
179     ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
180     ASSERT_EQ(interpreter.tensor(0)->bytes, 3 * test.size);
181     ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr);
182     ASSERT_EQ(interpreter.tensor(1)->bytes, 4 * test.size);
183     ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr);
184   }
185 }
186 
// Tensors configured with (single-channel) affine quantization must also
// expose the equivalent legacy `params.scale` / `params.zero_point`, for both
// read-write and read-only tensors.
TEST(BasicInterpreter, CheckQuantization) {
  Interpreter interpreter;
  ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
  ASSERT_EQ(interpreter.SetInputs({0, 1}), kTfLiteOk);
  ASSERT_EQ(interpreter.SetOutputs({}), kTfLiteOk);
  const TfLiteType tensor_type = kTfLiteInt8;
  // Backing data for the read-only int8 tensor.
  const int8_t int8s[] = {3, 4};
  const float scale = 0.5f;
  const int32_t zero_point = 12;

  // Builds a one-element affine quantization with `scale` / `zero_point`.
  // calloc (rather than malloc) zeroes every field of the struct, including
  // any the test does not set explicitly. The heap params are handed to the
  // interpreter below; NOTE(review): presumably it takes ownership and frees
  // them -- confirm against SetTensorParameters*.
  auto make_affine_quantization = [scale, zero_point]() {
    auto* affine = static_cast<TfLiteAffineQuantization*>(
        calloc(1, sizeof(TfLiteAffineQuantization)));
    affine->scale = TfLiteFloatArrayCreate(1);
    affine->zero_point = TfLiteIntArrayCreate(1);
    affine->scale->data[0] = scale;
    affine->zero_point->data[0] = zero_point;
    TfLiteQuantization quantization;
    quantization.type = kTfLiteAffineQuantization;
    quantization.params = affine;
    return quantization;
  };

  const TfLiteQuantization rw_quantization = make_affine_quantization();
  const TfLiteQuantization ro_quantization = make_affine_quantization();

  ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, tensor_type, "", {3},
                                                     rw_quantization),
            kTfLiteOk);
  ASSERT_EQ(interpreter.SetTensorParametersReadOnly(
                1, tensor_type, "", {2}, ro_quantization,
                reinterpret_cast<const char*>(int8s), sizeof(int8s)),
            kTfLiteOk);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
  // Check that the legacy scale and zero_point are set correctly.
  ASSERT_EQ(interpreter.tensor(0)->params.scale, scale);
  ASSERT_EQ(interpreter.tensor(0)->params.zero_point, zero_point);
  ASSERT_EQ(interpreter.tensor(0)->quantization.type, rw_quantization.type);
  ASSERT_EQ(interpreter.tensor(1)->params.scale, scale);
  ASSERT_EQ(interpreter.tensor(1)->params.zero_point, zero_point);
  ASSERT_EQ(interpreter.tensor(1)->quantization.type, ro_quantization.type);
}
233 
// A read-write input can be resized after allocation. A read-only (mmapped)
// input cannot, and re-binding it to a buffer too small for its shape is
// rejected while leaving the previously bound buffer usable.
TEST(BasicInterpreter, CheckResize) {
  const float floats[] = {-3.f, -4.f};
  const int32_t int32s[] = {-3, -4};
  const uint8_t uint8s[] = {3, 4};
  const int64_t int64s[] = {6, -7};
  const int16_t int16s[] = {8, -9};

  struct {
    TfLiteType type;
    size_t size;
    const char* array;
  } cases[] = {
      {kTfLiteFloat32, sizeof(float), reinterpret_cast<const char*>(floats)},
      {kTfLiteInt32, sizeof(int32_t), reinterpret_cast<const char*>(int32s)},
      {kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast<const char*>(uint8s)},
      {kTfLiteInt64, sizeof(int64_t), reinterpret_cast<const char*>(int64s)},
      {kTfLiteInt16, sizeof(int16_t), reinterpret_cast<const char*>(int16s)},
  };

  for (const auto& test : cases) {
    Interpreter interpreter;

    ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
    ASSERT_EQ(interpreter.SetInputs({0, 1}), kTfLiteOk);
    ASSERT_EQ(interpreter.SetOutputs({}), kTfLiteOk);
    // Value-initialize so no indeterminate scale/zero_point is passed on.
    const TfLiteQuantizationParams quant{};

    ASSERT_EQ(
        interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant),
        kTfLiteOk);
    ASSERT_EQ(interpreter.SetTensorParametersReadOnly(
                  1, test.type, "", {2}, quant, test.array, 2 * test.size),
              kTfLiteOk);
    ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
    ASSERT_EQ(interpreter.ResizeInputTensor(0, {1, 2}), kTfLiteOk);
    // Resizing a mmapped tensor is not allowed and should produce an error.
    ASSERT_NE(interpreter.ResizeInputTensor(1, {3}), kTfLiteOk);
    // Re-bind the mmapped tensor to a buffer that is too small for its
    // dimensions; this must be rejected.
    ASSERT_NE(interpreter.SetTensorParametersReadOnly(
                  1, test.type, "", {2}, quant, test.array, 1 * test.size),
              kTfLiteOk);
    // Allocation still succeeds because the last valid buffer is kept.
    ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
  }
}
281 
// Every allocated tensor buffer must be 4-byte aligned, even when a tensor's
// element count (2*i+1) makes its byte size odd (e.g. for uint8).
TEST(BasicInterpreter, CheckAlignment) {
  struct {
    TfLiteType type;
  } cases[] = {
      {kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt8},
      {kTfLiteInt64},   {kTfLiteInt16},
  };

  for (const auto& test : cases) {
    Interpreter interpreter;

    ASSERT_EQ(interpreter.AddTensors(4), kTfLiteOk);
    // Mark every tensor as a graph input so that the planner actually gives
    // it a buffer. Without that, data.raw could stay null, and since
    // nullptr % 4 == 0 the alignment check below would pass vacuously.
    ASSERT_EQ(interpreter.SetInputs({0, 1, 2, 3}), kTfLiteOk);

    // Value-initialize so no indeterminate scale/zero_point is passed on.
    const TfLiteQuantizationParams quant{};
    for (int i = 0; i < 4; i++) {
      ASSERT_EQ(interpreter.SetTensorParametersReadWrite(i, test.type, "",
                                                         {2 * i + 1}, quant),
                kTfLiteOk);
    }
    ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
    for (int i = 0; i < 4; i++) {
      const TfLiteTensor& tensor = *interpreter.tensor(i);
      ASSERT_NE(tensor.data.raw, nullptr);
      ASSERT_EQ(reinterpret_cast<intptr_t>(tensor.data.raw) % 4, 0);
    }
  }
}
307 
// Pins the arena planner's layout for a five-node chain of no-op nodes.
// Tensors whose lifetimes do not overlap are expected to share memory
// (6 reuses 2's buffer, 9 reuses 5's), and the zero-sized tensor 8 gets no
// buffer at all.
TEST(BasicInterpreter, CheckArenaAllocation) {
  Interpreter interpreter;
  ASSERT_EQ(interpreter.AddTensors(10), kTfLiteOk);

  // Value-initialize so no indeterminate scale/zero_point is passed on.
  const TfLiteQuantizationParams quant{};
  TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};

  const std::vector<int> sizes{2048, 4096, 1023, 2047, 1021,
                               2047, 1023, 2046, 0,    2048};
  for (int i = 0; i < static_cast<int>(sizes.size()); ++i) {
    ASSERT_EQ(interpreter.SetTensorParametersReadWrite(i, kTfLiteUInt8, "",
                                                       {sizes[i]}, quant),
              kTfLiteOk);
  }
  ASSERT_EQ(interpreter.SetInputs({0, 1}), kTfLiteOk);
  ASSERT_EQ(interpreter.SetOutputs({9, 4}), kTfLiteOk);
  ASSERT_EQ(interpreter.AddNodeWithParameters({0, 1}, {2, 3}, nullptr, 0,
                                              nullptr, &reg),
            kTfLiteOk);
  ASSERT_EQ(interpreter.AddNodeWithParameters({2, 1}, {4, 5}, nullptr, 0,
                                              nullptr, &reg),
            kTfLiteOk);
  ASSERT_EQ(interpreter.AddNodeWithParameters({4, 3}, {6, 7}, nullptr, 0,
                                              nullptr, &reg),
            kTfLiteOk);
  ASSERT_EQ(
      interpreter.AddNodeWithParameters({6, 5}, {8}, nullptr, 0, nullptr, &reg),
      kTfLiteOk);
  ASSERT_EQ(
      interpreter.AddNodeWithParameters({8, 7}, {9}, nullptr, 0, nullptr, &reg),
      kTfLiteOk);

  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

  ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(1)->data.raw);
  ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(2)->data.raw);
  ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(3)->data.raw);
  ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(4)->data.raw);
  ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(5)->data.raw);
  ASSERT_LT(interpreter.tensor(5)->data.raw, interpreter.tensor(7)->data.raw);
  ASSERT_EQ(interpreter.tensor(6)->data.raw, interpreter.tensor(2)->data.raw);
  // #7 is the one with the largest pointer.
  // #8 has zero elements, so it is left without a buffer.
  ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr);
  ASSERT_EQ(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw);
}
342 
// typed_tensor<T>() must return the buffer only when T matches the tensor's
// element type, and that pointer must alias the raw C-API data pointer.
TEST(BasicInterpreter, BufferAccess) {
  Interpreter interpreter;
  ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
  ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);

  const TfLiteQuantizationParams no_quantization = TfLiteQuantizationParams();
  ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "",
                                                     {3}, no_quantization),
            kTfLiteOk);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

  float* const as_float = interpreter.typed_tensor<float>(0);
  // Matching element type: a valid pointer is returned.
  ASSERT_NE(as_float, nullptr);
  // Mismatched element type: no pointer is returned.
  ASSERT_EQ(interpreter.typed_tensor<int>(0), nullptr);
  // The typed accessor and the raw C interface refer to the same memory.
  ASSERT_EQ(as_float, interpreter.tensor(0)->data.f);
}
359 
// A single tensor that is both the graph input and output, with no ops:
// resizing, allocating and invoking must all succeed.
TEST(BasicInterpreter, NoOpInterpreter) {
  constexpr int kOnlyTensor = 0;
  Interpreter interpreter;
  ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
  ASSERT_EQ(interpreter.SetInputs({kOnlyTensor}), kTfLiteOk);
  ASSERT_EQ(interpreter.SetOutputs({kOnlyTensor}), kTfLiteOk);

  const TfLiteStatus params_status = interpreter.SetTensorParametersReadWrite(
      kOnlyTensor, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams());
  ASSERT_EQ(params_status, kTfLiteOk);

  const int input_index = interpreter.inputs()[0];
  ASSERT_EQ(interpreter.ResizeInputTensor(input_index, {1, 2, 3}), kTfLiteOk);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
  ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
}
375 
// Calling AllocateTensors() a second time with nothing changed in between
// must leave existing buffers where they are.
TEST(BasicInterpreter, RedundantAllocateTensors) {
  Interpreter interpreter;
  ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
  ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);

  const TfLiteStatus params_status = interpreter.SetTensorParametersReadWrite(
      0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams());
  ASSERT_EQ(params_status, kTfLiteOk);

  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
  char* const first_buffer = interpreter.tensor(0)->data.raw;
  ASSERT_NE(first_buffer, nullptr);

  // Redundant request: the buffer must not move.
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
  ASSERT_EQ(interpreter.tensor(0)->data.raw, first_buffer);
}
393 
// When an input tensor is dynamic, a repeated AllocateTensors() call must
// not be short-circuited: it has to allocate the rest of the graph again.
TEST(BasicInterpreter, RedundantAllocateTensorsWithDynamicInputs) {
  Interpreter interpreter;
  TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
  ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
  ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
  ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk);
  ASSERT_EQ(
      interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg),
      kTfLiteOk);

  ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
                0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
            kTfLiteOk);
  ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
                1, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
            kTfLiteOk);

  // Configure the input tensor as dynamic. data.raw is cleared first so no
  // arena pointer is later treated as heap memory.
  interpreter.tensor(0)->data.raw = nullptr;
  interpreter.tensor(0)->allocation_type = kTfLiteDynamic;

  ASSERT_EQ(interpreter.ResizeInputTensor(interpreter.inputs()[0], {1, 2, 3}),
            kTfLiteOk);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
  ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr);

  // Reset the output tensor's buffer.
  interpreter.tensor(1)->data.raw = nullptr;

  // A redundant allocation request should be honored, as the input tensor
  // was marked dynamic.
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
  ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr);
}
426 
// Exercises ResizeInputTensor() on an arena tensor and then on a dynamic
// tensor, including rank-0 ({}) and zero-element shapes, checking the byte
// size after every resize.
TEST(BasicInterpreter, ResizingTensors) {
  Interpreter interpreter;
  ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
  ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
  ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);

  const TfLiteStatus params_status = interpreter.SetTensorParametersReadWrite(
      0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams());
  ASSERT_EQ(params_status, kTfLiteOk);

  const int input = interpreter.inputs()[0];
  TfLiteTensor* const tensor = interpreter.tensor(input);
  constexpr size_t kFloatBytes = sizeof(float);

  // Arena-allocated tensor.
  ASSERT_EQ(interpreter.ResizeInputTensor(input, {1, 2, 3}), kTfLiteOk);
  EXPECT_EQ(tensor->bytes, 6 * kFloatBytes);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

  tensor->data.f[5] = 0.123f;

  // Switching from kTfLiteArenaRw to kTfLiteDynamic is not trivial: data.raw
  // has to be cleared first, otherwise Realloc will try to free arena memory.
  tensor->data.raw = nullptr;
  tensor->allocation_type = kTfLiteDynamic;

  ASSERT_EQ(interpreter.ResizeInputTensor(input, {1, 2, 4}), kTfLiteOk);
  EXPECT_EQ(tensor->bytes, 8 * kFloatBytes);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

  // Rank 0: a single element.
  ASSERT_EQ(interpreter.ResizeInputTensor(input, {}), kTfLiteOk);
  EXPECT_EQ(tensor->bytes, 1 * kFloatBytes);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

  // Any zero-length dimension means an empty buffer.
  ASSERT_EQ(interpreter.ResizeInputTensor(input, {0}), kTfLiteOk);
  EXPECT_EQ(tensor->bytes, size_t{0});
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

  ASSERT_EQ(interpreter.ResizeInputTensor(input, {1, 2, 0}), kTfLiteOk);
  EXPECT_EQ(tensor->bytes, size_t{0});
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

  // TODO(ahentz): We shouldn't have to force reallocation, but
  // ResizeInputTensor doesn't realloc dynamic tensors. Also note that
  // TfLiteTensorRealloc(tensor->bytes, tensor) is a no-op.
  TfLiteTensorRealloc(9 * kFloatBytes, tensor);
  tensor->data.f[7] = 0.123f;

  ASSERT_EQ(interpreter.ResizeInputTensor(input, {2, 2, 4}), kTfLiteOk);
  EXPECT_EQ(tensor->bytes, 16 * kFloatBytes);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

  // TODO(ahentz): We shouldn't have to force reallocation, but
  // ResizeInputTensor doesn't realloc dynamic tensors. Also note that
  // TfLiteTensorRealloc(tensor->bytes, tensor) is a no-op.
  TfLiteTensorRealloc(17 * kFloatBytes, tensor);
  tensor->data.f[15] = 0.123f;
}
483 
// Resizing an input to the shape it already has must not reallocate: the
// buffer and its contents survive both the resize and a later
// AllocateTensors().
TEST(BasicInterpreter, NoopResizingTensors) {
  Interpreter interpreter;
  ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
  ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
  ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);

  const TfLiteStatus params_status = interpreter.SetTensorParametersReadWrite(
      0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams());
  ASSERT_EQ(params_status, kTfLiteOk);

  const int input = interpreter.inputs()[0];
  TfLiteTensor* const tensor = interpreter.tensor(input);
  const std::vector<int> shape = {1, 2, 3};
  constexpr size_t kExpectedBytes = 6 * sizeof(float);
  constexpr float kSentinel = 0.123f;

  ASSERT_EQ(interpreter.ResizeInputTensor(input, shape), kTfLiteOk);
  EXPECT_EQ(tensor->bytes, kExpectedBytes);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
  tensor->data.f[5] = kSentinel;

  // Same shape again: nothing is reallocated, so the sentinel survives.
  ASSERT_EQ(interpreter.ResizeInputTensor(input, shape), kTfLiteOk);
  EXPECT_EQ(tensor->bytes, kExpectedBytes);
  ASSERT_NE(tensor->data.raw, nullptr);
  ASSERT_EQ(tensor->data.f[5], kSentinel);

  // No effective resize happened, so explicit allocation is a no-op.
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
  EXPECT_EQ(tensor->bytes, kExpectedBytes);
  ASSERT_NE(tensor->data.raw, nullptr);
  ASSERT_EQ(tensor->data.f[5], kSentinel);
}
514 
TEST(BasicInterpreter,OneOpInterpreter)515 TEST(BasicInterpreter, OneOpInterpreter) {
516   Interpreter interpreter;
517   ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
518   ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
519   ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk);
520 
521   TfLiteQuantizationParams quantized;
522   ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "in1",
523                                                      {3}, quantized),
524             kTfLiteOk);
525   ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "out0",
526                                                      {3}, quantized),
527             kTfLiteOk);
528 
529   ASSERT_EQ(interpreter.GetInputName(0), "in1");
530   ASSERT_EQ(interpreter.GetOutputName(0), "out0");
531 
532   TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
533   reg.init = [](TfLiteContext* context, const char*, size_t) -> void* {
534     auto* first_new_tensor = new int;
535     context->AddTensors(context, 2, first_new_tensor);
536     return first_new_tensor;
537   };
538   reg.free = [](TfLiteContext* context, void* buffer) {
539     delete reinterpret_cast<int*>(buffer);
540   };
541   reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
542     auto* first_new_tensor = reinterpret_cast<int*>(node->user_data);
543 
544     TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]];
545     TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]];
546 
547     TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
548     TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor1, newSize));
549 
550     TfLiteIntArrayFree(node->temporaries);
551     node->temporaries = TfLiteIntArrayCreate(2);
552     for (int i = 0; i < 2; ++i) {
553       node->temporaries->data[i] = *(first_new_tensor) + i;
554     }
555 
556     auto setup_temporary = [&](int id) {
557       TfLiteTensor* tmp = &context->tensors[id];
558       tmp->type = kTfLiteFloat32;
559       tmp->allocation_type = kTfLiteArenaRw;
560       return context->ResizeTensor(context, tmp,
561                                    TfLiteIntArrayCopy(tensor0->dims));
562     };
563     TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[0]));
564     TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[1]));
565 
566     return kTfLiteOk;
567   };
568   reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
569     TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
570 
571     auto populate = [&](int id) {
572       TfLiteTensor* t = &context->tensors[id];
573       int num = a0->dims->data[0];
574       for (int i = 0; i < num; i++) {
575         t->data.f[i] = a0->data.f[i];
576       }
577     };
578 
579     populate(node->outputs->data[0]);
580     populate(node->temporaries->data[0]);
581     populate(node->temporaries->data[1]);
582     return kTfLiteOk;
583   };
584   ASSERT_EQ(
585       interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg),
586       kTfLiteOk);
587   ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk);
588   ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
589 
590   ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
591 }
592 
593 // Forcefully divides tensor allocation in three steps: one before invocation
594 // and two more at invocation time. This happens because we use string tensors
595 // and their sizes can't be determined until invocation time.
TEST(BasicInterpreter,ThreeStepAllocate)596 TEST(BasicInterpreter, ThreeStepAllocate) {
597   Interpreter interpreter;
598   ASSERT_EQ(interpreter.AddTensors(5), kTfLiteOk);
599   ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
600   ASSERT_EQ(interpreter.SetOutputs({4}), kTfLiteOk);
601 
602   TfLiteQuantizationParams quantized;
603   char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'A', 'B', 'C'};
604   // Read only string tensor.
605   ASSERT_EQ(interpreter.SetTensorParametersReadOnly(0, kTfLiteString, "", {1},
606                                                     quantized, data, 15),
607             kTfLiteOk);
608   // Read-write string tensor.
609   ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteString, "", {1},
610                                                      quantized),
611             kTfLiteOk);
612   ASSERT_EQ(interpreter.SetTensorParametersReadWrite(2, kTfLiteInt32, "", {1},
613                                                      quantized),
614             kTfLiteOk);
615   ASSERT_EQ(interpreter.SetTensorParametersReadWrite(3, kTfLiteString, "", {1},
616                                                      quantized),
617             kTfLiteOk);
618   ASSERT_EQ(interpreter.SetTensorParametersReadWrite(4, kTfLiteInt32, "", {1},
619                                                      quantized),
620             kTfLiteOk);
621 
622   // String-in String-out node.
623   TfLiteRegistration reg_copy = {nullptr, nullptr, nullptr, nullptr};
624   reg_copy.invoke = [](TfLiteContext* context, TfLiteNode* node) {
625     TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
626     TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
627     DynamicBuffer buf;
628     StringRef str_ref = GetString(input, 0);
629     buf.AddString(str_ref);
630     buf.WriteToTensorAsVector(output);
631     return kTfLiteOk;
632   };
633 
634   // String-in Int-out node.
635   TfLiteRegistration reg_len = {nullptr, nullptr, nullptr, nullptr};
636   reg_len.prepare = [](TfLiteContext* context, TfLiteNode* node) {
637     TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
638     TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
639     outputSize->data[0] = 1;
640     return context->ResizeTensor(context, output, outputSize);
641   };
642   reg_len.invoke = [](TfLiteContext* context, TfLiteNode* node) {
643     TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
644     TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]];
645     a1->data.i32[0] = a0->bytes;
646     return kTfLiteOk;
647   };
648 
649   ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
650                                               &reg_copy),
651             kTfLiteOk);
652   ASSERT_EQ(interpreter.AddNodeWithParameters({1}, {2}, nullptr, 0, nullptr,
653                                               &reg_len),
654             kTfLiteOk);
655   ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {3}, nullptr, 0, nullptr,
656                                               &reg_copy),
657             kTfLiteOk);
658   ASSERT_EQ(interpreter.AddNodeWithParameters({3}, {4}, nullptr, 0, nullptr,
659                                               &reg_len),
660             kTfLiteOk);
661 
662   ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
663   ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
664 
665   ASSERT_EQ(interpreter.tensor(0)->bytes, 15);
666   ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr);
667   ASSERT_EQ(interpreter.tensor(1)->bytes, 15);
668   ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr);
669   ASSERT_EQ(interpreter.tensor(3)->bytes, 15);
670   ASSERT_NE(interpreter.tensor(4)->data.raw, nullptr);
671   ASSERT_EQ(interpreter.tensor(2)->bytes, 4);
672   ASSERT_EQ(interpreter.tensor(2)->data.i32[0], 15);
673   ASSERT_EQ(interpreter.tensor(4)->bytes, 4);
674   ASSERT_EQ(interpreter.tensor(4)->data.i32[0], 15);
675 }
676 
// After a successful AllocateTensors() + Invoke(), allocating again must not
// move any tensor buffer.
TEST(BasicInterpreter, AllocateTwice) {
  Interpreter interpreter;
  ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
  ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
  ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk);

  // Value-initialize so no indeterminate scale/zero_point is passed on.
  const TfLiteQuantizationParams quantized{};
  for (const int index : {0, 1}) {
    ASSERT_EQ(interpreter.SetTensorParametersReadWrite(index, kTfLiteFloat32,
                                                       "", {3}, quantized),
              kTfLiteOk);
  }

  // Copy op: the output takes the input's shape, and invoke copies
  // dims[0] floats across.
  TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
  reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
    TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
    TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
    return context->ResizeTensor(context, output,
                                 TfLiteIntArrayCopy(input->dims));
  };
  reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
    TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
    TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
    const int num = input->dims->data[0];
    for (int i = 0; i < num; i++) {
      output->data.f[i] = input->data.f[i];
    }
    return kTfLiteOk;
  };
  ASSERT_EQ(
      interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg),
      kTfLiteOk);
  ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
  ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
  char* const old_tensor0_ptr = interpreter.tensor(0)->data.raw;
  char* const old_tensor1_ptr = interpreter.tensor(1)->data.raw;

  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
  ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
  ASSERT_EQ(old_tensor0_ptr, interpreter.tensor(0)->data.raw);
  ASSERT_EQ(old_tensor1_ptr, interpreter.tensor(1)->data.raw);
}
721 
TEST(BasicInterpreter, TestNullErrorReporter) {
  // Constructing an interpreter with a null ErrorReporter must be safe, and so
  // must reporting an error through it. Invoking a model that was never
  // prepared fails and reports an error, which exercises the reporter path.
  // NOTE(review): relies on Interpreter substituting its default reporter when
  // given nullptr; confirm against interpreter.cc.
  Interpreter interpreter(nullptr);
  ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
}
726 
TEST(BasicInterpreter, TestCustomErrorReporter) {
  // Errors raised by the interpreter must go to the reporter it was built with.
  TestErrorReporter reporter;
  Interpreter interpreter(&reporter);

  // No tensors were allocated, so Invoke() has to fail and report exactly once.
  const TfLiteStatus invoke_status = interpreter.Invoke();
  ASSERT_NE(invoke_status, kTfLiteOk);
  ASSERT_EQ(reporter.error_messages(),
            "Invoke called on model that is not ready.");
  ASSERT_EQ(reporter.num_calls(), 1);
}
735 
TEST(BasicInterpreter, TestUnsupportedDelegateFunctions) {
  Interpreter interpreter;
  ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
  // Positional aggregate initialization (init, free, prepare, invoke), as used
  // elsewhere in this file; designated initializers are only standard C++
  // from C++20 on.
  TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
  // These functions are only supported inside a delegate's Prepare function.
  // The test verifies that, when called from an ordinary op's prepare, they
  // return `kTfLiteError` rather than `kTfLiteOk`, and do not crash.
  registration.prepare = [](TfLiteContext* context,
                            TfLiteNode* node) -> TfLiteStatus {
    {
      TfLiteIntArray* execution_plan;
      EXPECT_EQ(context->GetExecutionPlan(context, &execution_plan),
                kTfLiteError);
    }
    {
      // Distinct names so the out-params do not shadow the `node` parameter.
      TfLiteNode* queried_node;
      TfLiteRegistration* queried_registration;
      EXPECT_EQ(context->GetNodeAndRegistration(context, 0, &queried_node,
                                                &queried_registration),
                kTfLiteError);
    }
    {
      TfLiteRegistration delegate_registration = {nullptr, nullptr, nullptr,
                                                  nullptr};
      // Create the (empty) node list with TfLiteIntArrayCreate(), as the rest
      // of this file does, instead of declaring a TfLiteIntArray on the stack.
      // NOTE(review): assumes TfLiteIntArray ends in a trailing `data` array
      // that is not meant for automatic storage.
      TfLiteIntArray* nodes_to_replace = TfLiteIntArrayCreate(0);
      EXPECT_EQ(context->ReplaceNodeSubsetsWithDelegateKernels(
                    context, delegate_registration, nodes_to_replace, nullptr),
                kTfLiteError);
      TfLiteIntArrayFree(nodes_to_replace);
    }
    return kTfLiteError;
  };
  ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
  ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
  ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
                                              &registration),
            kTfLiteOk);
  // prepare returns kTfLiteError, so allocation must fail.
  EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteError);
}
775 
TEST(BasicInterpreter, DynamicTensorsResizeDescendants) {
  // Assemble a graph with a node that has dynamically sized output (via the
  // pad op), followed by a node with a standard element-wise op (negate):
  //   tensor2 = pad(tensor0, paddings=tensor1); tensor3 = neg(tensor2).
  // The paddings are a graph input rather than a constant, so the pad output
  // shape (presumably) is only known once Invoke() reads them.
  Interpreter interpreter;
  ASSERT_EQ(interpreter.AddTensors(4), kTfLiteOk);
  ASSERT_EQ(interpreter.SetInputs({0, 1}), kTfLiteOk);
  ASSERT_EQ(interpreter.SetOutputs({3}), kTfLiteOk);
  TfLiteQuantizationParams quant = {};
  ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
                0, kTfLiteFloat32, "", {2, 2, 1, 1}, quant),
            kTfLiteOk);
  ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteInt32, "",
                                                     {4, 2}, quant),
            kTfLiteOk);
  ASSERT_EQ(
      interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {}, quant),
      kTfLiteOk);
  ASSERT_EQ(
      interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {}, quant),
      kTfLiteOk);

  TfLiteRegistration* pad_op = tflite::ops::builtin::Register_PADV2();
  TfLiteRegistration* neg_op = tflite::ops::builtin::Register_NEG();
  ASSERT_EQ(interpreter.AddNodeWithParameters({0, 1}, {2}, nullptr, 0, nullptr,
                                              pad_op),
            kTfLiteOk);
  ASSERT_EQ(
      interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, neg_op),
      kTfLiteOk);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

  // Writes the 4x2 paddings tensor (tensor 1) in row-major order, i.e. one
  // [before, after] pair per dimension of the 4-D input.
  auto set_paddings = [&interpreter](const std::vector<int>& paddings) {
    int* data = interpreter.typed_tensor<int>(1);
    for (size_t i = 0; i < paddings.size(); ++i) {
      data[i] = paddings[i];
    }
  };

  // Configure [[2,2],[2,2],[0,0],[0,0]] padding and execute the graph: the
  // 2x2x1x1 input becomes 6x6x1x1.
  set_paddings({2, 2, 2, 2, 0, 0, 0, 0});
  ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);

  // Both the output and intermediate tensor sizes should reflect the output
  // from the dynamic pad operation.
  ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 6 * 6);
  ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 6 * 6);

  // Now configure [[4,4],[6,6],[0,0],[0,0]] padding and execute the graph:
  // the 2x2x1x1 input becomes 10x14x1x1.
  set_paddings({4, 4, 6, 6, 0, 0, 0, 0});
  ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);

  // Again, the output and intermediate tensor sizes should reflect the *new*
  // resize from the latest pad operation.
  ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 10 * 14);
  ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 10 * 14);
}
828 
TEST(InterpreterTensorsCapacityTest, TestWithinHeadroom) {
  // Fill the interpreter up to kTensorsReservedCapacity, then, from inside an
  // op's prepare, add exactly kTensorsCapacityHeadroom more tensors. This test
  // asserts that doing so does not reallocate the context's tensor array.
  Interpreter interpreter;
  ASSERT_EQ(interpreter.AddTensors(Interpreter::kTensorsReservedCapacity),
            kTfLiteOk);
  TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
  registration.prepare = [](TfLiteContext* context,
                            TfLiteNode* node) -> TfLiteStatus {
    TfLiteTensor* first_tensor = context->tensors;

    int new_tensor_index;
    // Check the status too: a failed AddTensors would leave `tensors`
    // untouched and make the pointer check below pass vacuously.
    EXPECT_EQ(context->AddTensors(context, Interpreter::kTensorsCapacityHeadroom,
                                  &new_tensor_index),
              kTfLiteOk);
    EXPECT_EQ(first_tensor, context->tensors);
    return kTfLiteOk;
  };
  ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
                                              &registration),
            kTfLiteOk);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
}
848 
TEST(InterpreterTensorsCapacityTest, TestExceedHeadroom) {
  // Fill the interpreter up to kTensorsReservedCapacity, then, from inside an
  // op's prepare, add one tensor more than kTensorsCapacityHeadroom. This test
  // asserts that the context's tensor array is reallocated as a result.
  Interpreter interpreter;
  ASSERT_EQ(interpreter.AddTensors(Interpreter::kTensorsReservedCapacity),
            kTfLiteOk);
  TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
  registration.prepare = [](TfLiteContext* context,
                            TfLiteNode* node) -> TfLiteStatus {
    TfLiteTensor* first_tensor = context->tensors;

    int new_tensor_index;
    // The tensors must actually have been added for the reallocation check
    // below to be meaningful.
    EXPECT_EQ(
        context->AddTensors(context, Interpreter::kTensorsCapacityHeadroom + 1,
                            &new_tensor_index),
        kTfLiteOk);
    EXPECT_NE(first_tensor, context->tensors);
    return kTfLiteOk;
  };
  ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
                                              &registration),
            kTfLiteOk);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
}
868 
869 struct TestExternalContext : public TfLiteExternalContext {
870   static const TfLiteExternalContextType kType = kTfLiteGemmLowpContext;
871 
Gettflite::__anonea31be420211::TestExternalContext872   static TestExternalContext* Get(TfLiteContext* context) {
873     return reinterpret_cast<TestExternalContext*>(
874         context->GetExternalContext(context, kType));
875   }
876 
Settflite::__anonea31be420211::TestExternalContext877   static void Set(TfLiteContext* context, TestExternalContext* value) {
878     context->SetExternalContext(context, kType, value);
879   }
880 
881   int num_refreshes = 0;
882 };
883 
TEST_F(InterpreterTest, GetSetResetExternalContexts) {
  TfLiteContext* const interpreter_context = GetInterpreterContext();

  // Refresh callback: bump the counter of whichever TestExternalContext is
  // currently installed, if any.
  TestExternalContext external_context;
  external_context.Refresh = [](TfLiteContext* ctx) -> TfLiteStatus {
    TestExternalContext* installed = TestExternalContext::Get(ctx);
    if (installed) ++installed->num_refreshes;
    return kTfLiteOk;
  };

  // Nothing installed yet.
  EXPECT_EQ(TestExternalContext::Get(interpreter_context), nullptr);
  interpreter_.SetNumThreads(4);

  // Installed: the two SetNumThreads calls below must yield two refreshes.
  TestExternalContext::Set(interpreter_context, &external_context);
  EXPECT_EQ(TestExternalContext::Get(interpreter_context), &external_context);
  interpreter_.SetNumThreads(4);
  interpreter_.SetNumThreads(5);
  EXPECT_EQ(external_context.num_refreshes, 2);

  // Cleared again: the slot reads back as empty.
  TestExternalContext::Set(interpreter_context, nullptr);
  EXPECT_EQ(TestExternalContext::Get(interpreter_context), nullptr);
  interpreter_.SetNumThreads(4);
}
909 
910 // Test fixture that allows playing with execution plans. It creates a two
911 // node graph that can be executed in either [0,1] order or [1,0] order.
912 // The CopyOp records when it is invoked in the class member run_order_
913 // so we can test whether the execution plan was honored.
914 class TestExecutionPlan : public ::testing::Test {
915   // Encapsulates the node ids and provides them to a C primitive data type
916   // Allocatable with placement new, but never destructed, so make sure this
917   // doesn't own any heap allocated data. This is then is used as op local
918   // data to allow access to the test fixture data.
919   class CallReporting {
920    public:
CallReporting(int node_id,std::vector<int> * run_order)921     CallReporting(int node_id, std::vector<int>* run_order)
922         : node_id_(node_id), run_order_(run_order) {}
923 
Record()924     void Record() { run_order_->push_back(node_id_); }
925 
926    private:
927     // The node id for this particular node
928     int node_id_;
929     // A pointer to the global run-order
930     std::vector<int>* run_order_;
931   };
932 
933   // Build a kernel registration for an op that copies its one input
934   // to an output
CopyOpRegistration()935   TfLiteRegistration CopyOpRegistration() {
936     TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
937 
938     reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
939       // Set output size to input size
940       TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]];
941       TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]];
942       TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
943       return context->ResizeTensor(context, tensor1, newSize);
944     };
945 
946     reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
947       CallReporting* call_reporting =
948           reinterpret_cast<CallReporting*>(node->builtin_data);
949       // Copy input data to output data.
950       TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
951       TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]];
952       int num = a0->dims->data[0];
953       for (int i = 0; i < num; i++) {
954         a1->data.f[i] = a0->data.f[i];
955       }
956       call_reporting->Record();
957       return kTfLiteOk;
958     };
959     return reg;
960   }
961 
962   // Adds a copy node going from tensor `input` to output tensor `output`.
963   // Note, input is used as the node_id. Inject run_order as op accessible
964   // data. Note: this is a little strange of a way to do this, but it is
965   // using op functionality to avoid static global variables.
MakeCopyNode(int input,int output)966   void MakeCopyNode(int input, int output) {
967     // Ownership of call_reporting is taken by interpreter (malloc is used due
968     // to nodes being a C99 interface so free() is used).
969     TfLiteRegistration copy_op = CopyOpRegistration();
970     CallReporting* call_reporting_1 =
971         reinterpret_cast<CallReporting*>(malloc(sizeof(CallReporting)));
972     new (call_reporting_1) CallReporting(input, &run_order_);
973     ASSERT_EQ(interpreter_.AddNodeWithParameters(
974                   {0}, {2}, nullptr, 0,
975                   reinterpret_cast<void*>(call_reporting_1), &copy_op),
976               kTfLiteOk);
977     ASSERT_EQ(interpreter_.ResizeInputTensor(input, {3}), kTfLiteOk);
978   }
979 
SetUp()980   void SetUp() final {
981     // Add two inputs and two outputs that don't depend on each other
982     ASSERT_EQ(interpreter_.AddTensors(4), kTfLiteOk);
983     interpreter_.SetInputs({0, 1});
984     interpreter_.SetOutputs({2, 3});
985     TfLiteQuantizationParams quantized;
986     for (int tensor_index = 0; tensor_index < 4; tensor_index++) {
987       ASSERT_EQ(interpreter_.SetTensorParametersReadWrite(
988                     tensor_index, kTfLiteFloat32, "", {3}, quantized),
989                 kTfLiteOk);
990     }
991 
992     // Define two copy functions that also use the user_data to report that
993     // they were called.
994     // i.e. tensor[2] = copy(tensor[0]); tensor[3] = copy(tensor[1]);
995     // thus we can reorder the two nodes arbitrary and still satisfy dependency
996     // order.
997     MakeCopyNode(0, 2);
998     MakeCopyNode(1, 3);
999 
1000     ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
1001   }
1002 
1003  protected:
1004   Interpreter interpreter_;
1005 
1006   // list of node_ids that were run
1007   std::vector<int> run_order_;
1008 };
1009 
TEST_F(TestExecutionPlan, DefaultExecutionPlan) {
  // With no explicit plan, nodes run in the order they were added.
  const std::vector<int> expected_order = {0, 1};
  ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
  ASSERT_EQ(run_order_, expected_order);
}
1015 
TEST_F(TestExecutionPlan, ReversedExecutionPlan) {
  // Check reversed order. The plan itself must be accepted, otherwise the
  // run-order check below would be testing the default plan.
  ASSERT_EQ(interpreter_.SetExecutionPlan({1, 0}), kTfLiteOk);
  ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
  ASSERT_EQ(run_order_, std::vector<int>({1, 0}));
}
1022 
TEST_F(TestExecutionPlan, SubsetExecutionPlan) {
  // Check running only node index 1. The plan itself must be accepted.
  ASSERT_EQ(interpreter_.SetExecutionPlan({1}), kTfLiteOk);
  ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
  ASSERT_EQ(run_order_, std::vector<int>({1}));
}
1029 
TEST_F(TestExecutionPlan, NullExecutionPlan) {
  // Check nothing is executed with an empty plan. The plan itself must be
  // accepted.
  ASSERT_EQ(interpreter_.SetExecutionPlan({}), kTfLiteOk);
  ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
  ASSERT_EQ(run_order_, std::vector<int>());
}
1036 
1037 // Build a kernel registration for an op that copies its one input
1038 // to an output
AddOpRegistration()1039 TfLiteRegistration AddOpRegistration() {
1040   TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
1041 
1042   reg.custom_name = "my_add";
1043   reg.builtin_code = tflite::BuiltinOperator_CUSTOM;
1044 
1045   reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
1046     // Set output size to input size
1047     TfLiteTensor* input1 = &context->tensors[node->inputs->data[0]];
1048     TfLiteTensor* input2 = &context->tensors[node->inputs->data[1]];
1049     TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
1050 
1051     TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size);
1052     for (int i = 0; i < input1->dims->size; ++i) {
1053       TF_LITE_ENSURE_EQ(context, input1->dims->data[i], input2->dims->data[i]);
1054     }
1055 
1056     TF_LITE_ENSURE_STATUS(context->ResizeTensor(
1057         context, output, TfLiteIntArrayCopy(input1->dims)));
1058     return kTfLiteOk;
1059   };
1060 
1061   reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
1062     // Copy input data to output data.
1063     TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
1064     TfLiteTensor* a1 = &context->tensors[node->inputs->data[1]];
1065     TfLiteTensor* out = &context->tensors[node->outputs->data[0]];
1066     int num = a0->dims->data[0];
1067     for (int i = 0; i < num; i++) {
1068       out->data.f[i] = a0->data.f[i] + a1->data.f[i];
1069     }
1070     return kTfLiteOk;
1071   };
1072   return reg;
1073 }
1074 
1075 class TestDelegate : public ::testing::Test {
1076  protected:
SetUp()1077   void SetUp() override {
1078     interpreter_.reset(new Interpreter);
1079     interpreter_->AddTensors(5);
1080     interpreter_->SetInputs({0, 1});
1081     interpreter_->SetOutputs({3, 4});
1082     TfLiteQuantizationParams quant;
1083     interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
1084                                                quant);
1085     interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
1086                                                quant);
1087     interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3},
1088                                                quant);
1089     interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3},
1090                                                quant);
1091     interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3},
1092                                                quant);
1093     TfLiteRegistration reg = AddOpRegistration();
1094     interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, &reg);
1095     interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, &reg);
1096     interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, &reg);
1097   }
1098 
TearDown()1099   void TearDown() override {
1100     // Interpreter relies on delegate_ to free the resources properly. Thus
1101     // the life cycle of delegate must be longer than interpreter.
1102     interpreter_.reset();
1103     delegate_.reset();
1104   }
1105 
1106   TfLiteBufferHandle last_allocated_handle_ = kTfLiteNullBufferHandle;
1107 
AllocateBufferHandle()1108   TfLiteBufferHandle AllocateBufferHandle() { return ++last_allocated_handle_; }
1109 
1110  protected:
1111   class SimpleDelegate {
1112    public:
1113     // Create a simple implementation of a TfLiteDelegate. We use the C++ class
1114     // SimpleDelegate and it can produce a handle TfLiteDelegate that is
1115     // value-copyable and compatible with TfLite.
SimpleDelegate(const std::vector<int> & nodes)1116     explicit SimpleDelegate(const std::vector<int>& nodes) : nodes_(nodes) {
1117       delegate_.Prepare = [](TfLiteContext* context,
1118                              TfLiteDelegate* delegate) -> TfLiteStatus {
1119         auto* simple = reinterpret_cast<SimpleDelegate*>(delegate->data_);
1120         TfLiteIntArray* nodes_to_separate =
1121             TfLiteIntArrayCreate(simple->nodes_.size());
1122         // Mark nodes that we want in TfLiteIntArray* structure.
1123         int index = 0;
1124         for (auto node_index : simple->nodes_) {
1125           nodes_to_separate->data[index++] = node_index;
1126           // make sure node is add
1127           TfLiteNode* node;
1128           TfLiteRegistration* reg;
1129           context->GetNodeAndRegistration(context, node_index, &node, &reg);
1130           TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
1131           TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
1132         }
1133         // Check that all nodes are available
1134         TfLiteIntArray* execution_plan;
1135         TF_LITE_ENSURE_STATUS(
1136             context->GetExecutionPlan(context, &execution_plan));
1137         for (int exec_index = 0; exec_index < execution_plan->size;
1138              exec_index++) {
1139           int node_index = execution_plan->data[exec_index];
1140           // Check that we are an identity map to start.
1141           TFLITE_CHECK_EQ(exec_index, node_index);
1142           TfLiteNode* node;
1143           TfLiteRegistration* reg;
1144           context->GetNodeAndRegistration(context, node_index, &node, &reg);
1145           TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
1146           TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
1147         }
1148 
1149         context->ReplaceNodeSubsetsWithDelegateKernels(
1150             context, FakeFusedRegistration(), nodes_to_separate, delegate);
1151         TfLiteIntArrayFree(nodes_to_separate);
1152         return kTfLiteOk;
1153       };
1154       delegate_.CopyToBufferHandle = [](TfLiteContext* context,
1155                                         TfLiteDelegate* delegate,
1156                                         TfLiteBufferHandle buffer_handle,
1157                                         TfLiteTensor* tensor) -> TfLiteStatus {
1158         // TODO(ycling): Implement tests to test buffer copying logic.
1159         return kTfLiteOk;
1160       };
1161       delegate_.CopyFromBufferHandle =
1162           [](TfLiteContext* context, TfLiteDelegate* delegate,
1163              TfLiteBufferHandle buffer_handle,
1164              TfLiteTensor* output) -> TfLiteStatus {
1165         // TODO(ycling): Implement tests to test buffer copying logic.
1166         return kTfLiteOk;
1167       };
1168       delegate_.FreeBufferHandle =
1169           [](TfLiteContext* context, TfLiteDelegate* delegate,
1170              TfLiteBufferHandle* handle) { *handle = kTfLiteNullBufferHandle; };
1171       // Store type-punned data SimpleDelegate structure.
1172       delegate_.data_ = reinterpret_cast<void*>(this);
1173       delegate_.flags = kTfLiteDelegateFlagsNone;
1174     }
1175 
FakeFusedRegistration()1176     static TfLiteRegistration FakeFusedRegistration() {
1177       TfLiteRegistration reg = {nullptr};
1178       reg.custom_name = "fake_fused_op";
1179       return reg;
1180     }
1181 
get_tf_lite_delegate()1182     TfLiteDelegate* get_tf_lite_delegate() { return &delegate_; }
1183 
1184    private:
1185     std::vector<int> nodes_;
1186     TfLiteDelegate delegate_;
1187   };
1188   std::unique_ptr<Interpreter> interpreter_;
1189   std::unique_ptr<SimpleDelegate> delegate_;
1190 };
1191 
TEST_F(TestDelegate, BasicDelegate) {
  // Delegating every node must collapse the plan into one fused node whose
  // params describe the replaced nodes and the graph's boundary tensors.
  delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
  ASSERT_EQ(
      interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
      kTfLiteOk);

  ASSERT_EQ(interpreter_->execution_plan().size(), 1);
  int node = interpreter_->execution_plan()[0];
  const auto* node_and_reg = interpreter_->node_and_registration(node);
  ASSERT_NE(node_and_reg, nullptr);
  // Compare the names' contents, not the addresses of the string literals.
  EXPECT_STREQ(node_and_reg->second.custom_name,
               SimpleDelegate::FakeFusedRegistration().custom_name);

  const TfLiteDelegateParams* params =
      reinterpret_cast<const TfLiteDelegateParams*>(
          node_and_reg->first.builtin_data);
  ASSERT_NE(params, nullptr);
  ASSERT_EQ(params->nodes_to_replace->size, 3);
  EXPECT_EQ(params->nodes_to_replace->data[0], 0);
  EXPECT_EQ(params->nodes_to_replace->data[1], 1);
  EXPECT_EQ(params->nodes_to_replace->data[2], 2);

  ASSERT_EQ(params->input_tensors->size, 2);
  EXPECT_EQ(params->input_tensors->data[0], 0);
  EXPECT_EQ(params->input_tensors->data[1], 1);

  ASSERT_EQ(params->output_tensors->size, 2);
  EXPECT_EQ(params->output_tensors->data[0], 3);
  EXPECT_EQ(params->output_tensors->data[1], 4);
}
1218 
TEST_F(TestDelegate, StaticDelegateMakesGraphImmutable) {
  delegate_.reset(new SimpleDelegate({0, 1, 2}));
  TfLiteDelegate* const tflite_delegate = delegate_->get_tf_lite_delegate();

  // Applying the delegate once succeeds and fuses the whole graph.
  ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(tflite_delegate), kTfLiteOk);
  ASSERT_EQ(interpreter_->execution_plan().size(), 1);

  // The delegate does not support dynamic resizing, so any further graph
  // mutation (resizing an input, re-delegating) must be refused.
  ASSERT_NE(interpreter_->ResizeInputTensor(0, {0}), kTfLiteOk);
  ASSERT_NE(interpreter_->ModifyGraphWithDelegate(tflite_delegate), kTfLiteOk);
}
1233 
TEST_F(TestDelegate, ComplexDelegate) {
  // Delegate only nodes 1 and 2; node 0 stays with the interpreter.
  delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({1, 2}));
  ASSERT_EQ(
      interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
      kTfLiteOk);

  ASSERT_EQ(interpreter_->execution_plan().size(), 2);
  // 0th should be a non-delegated original op.
  ASSERT_EQ(interpreter_->execution_plan()[0], 0);
  // 1st should be a new macro op (3), which didn't exist before.
  ASSERT_EQ(interpreter_->execution_plan()[1], 3);
  const auto* node_and_reg = interpreter_->node_and_registration(3);
  ASSERT_NE(node_and_reg, nullptr);
  // Compare the names' contents, not the addresses of the string literals.
  ASSERT_STREQ(node_and_reg->second.custom_name,
               SimpleDelegate::FakeFusedRegistration().custom_name);
}
1247 
TEST_F(TestDelegate, SetBufferHandleToInput) {
  delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
  TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
  ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(delegate), kTfLiteOk);

  // Tensor 0 is a graph input (not an output), so before a buffer handle is
  // set it has no delegate attached.
  constexpr int kInputTensorIndex = 0;
  TfLiteTensor* tensor = interpreter_->tensor(kInputTensorIndex);
  ASSERT_EQ(tensor->delegate, nullptr);
  ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);

  TfLiteBufferHandle handle = AllocateBufferHandle();
  TfLiteStatus status =
      interpreter_->SetBufferHandle(kInputTensorIndex, handle, delegate);
  ASSERT_EQ(status, kTfLiteOk);
  EXPECT_EQ(tensor->delegate, delegate);
  EXPECT_EQ(tensor->buffer_handle, handle);
}
1265 
TEST_F(TestDelegate, SetBufferHandleToOutput) {
  delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
  TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
  ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(delegate), kTfLiteOk);

  constexpr int kOutputTensorIndex = 3;
  TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex);
  // Before setting the buffer handle, the tensor's `delegate` is already set
  // because it will be written by the delegate.
  ASSERT_EQ(tensor->delegate, delegate);
  ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);

  TfLiteBufferHandle handle = AllocateBufferHandle();
  TfLiteStatus status =
      interpreter_->SetBufferHandle(kOutputTensorIndex, handle, delegate);
  ASSERT_EQ(status, kTfLiteOk);
  EXPECT_EQ(tensor->delegate, delegate);
  EXPECT_EQ(tensor->buffer_handle, handle);
}
1285 
TEST_F(TestDelegate, SetInvalidHandleToTensor) {
  // NOTE(review): the result is deliberately ignored. Tensors have not been
  // allocated yet, so this Invoke() is expected to fail; presumably it checks
  // that a failed Invoke() does not get in the way of delegation. Confirm.
  interpreter_->Invoke();
  delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
  TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
  ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(delegate), kTfLiteOk);

  SimpleDelegate another_simple_delegate({0, 1, 2});

  constexpr int kOutputTensorIndex = 3;
  TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex);
  // Before setting the buffer handle, the tensor's `delegate` is already set
  // because it will be written by the delegate.
  ASSERT_EQ(tensor->delegate, delegate);
  ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);

  TfLiteBufferHandle handle = AllocateBufferHandle();
  TfLiteStatus status = interpreter_->SetBufferHandle(
      kOutputTensorIndex, handle,
      another_simple_delegate.get_tf_lite_delegate());
  // Setting a buffer handle to a tensor with another delegate will fail, and
  // must leave the tensor untouched.
  ASSERT_EQ(status, kTfLiteError);
  EXPECT_EQ(tensor->delegate, delegate);
  EXPECT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);
}
1310 
TEST_F(TestDelegate, ResizeInputWithNonDynamicDelegateShouldFail) {
  delegate_.reset(new SimpleDelegate({0, 1, 2}));

  // Before delegation, both inputs may be resized freely.
  for (int input_index : {0, 1}) {
    ASSERT_EQ(interpreter_->ResizeInputTensor(input_index, {1, 2}), kTfLiteOk);
  }

  // Once a delegate without dynamic-tensor support is applied, resizing an
  // input is rejected.
  TfLiteDelegate* const tflite_delegate = delegate_->get_tf_lite_delegate();
  ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(tflite_delegate), kTfLiteOk);
  ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 2}), kTfLiteError);
}
1320 
1321 class TestDelegateWithDynamicTensors : public ::testing::Test {
1322  protected:
SetUp()1323   void SetUp() override {
1324     interpreter_.reset(new Interpreter);
1325 
1326     interpreter_->AddTensors(2);
1327     interpreter_->SetInputs({0});
1328     interpreter_->SetOutputs({1});
1329     TfLiteQuantizationParams quant;
1330     interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
1331                                                quant);
1332     interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
1333                                                quant);
1334     TfLiteRegistration reg = DynamicCopyOpRegistration();
1335     interpreter_->AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg);
1336 
1337     delegate_.Prepare = [](TfLiteContext* context,
1338                            TfLiteDelegate* delegate) -> TfLiteStatus {
1339       // In this test, the delegate replaces all the nodes if this function is
1340       // called.
1341       TfLiteIntArray* execution_plan;
1342       TF_LITE_ENSURE_STATUS(
1343           context->GetExecutionPlan(context, &execution_plan));
1344       context->ReplaceNodeSubsetsWithDelegateKernels(
1345           context, DelegateRegistration(), execution_plan, delegate);
1346       return kTfLiteOk;
1347     };
1348     delegate_.flags = kTfLiteDelegateFlagsNone;
1349   }
1350 
DynamicCopyOpRegistration()1351   static TfLiteRegistration DynamicCopyOpRegistration() {
1352     TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
1353 
1354     reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
1355       TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
1356       SetTensorToDynamic(output);
1357       return kTfLiteOk;
1358     };
1359 
1360     reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
1361       // Not implemented since this isn't required in testing.
1362       return kTfLiteOk;
1363     };
1364     return reg;
1365   }
1366 
DelegateRegistration()1367   static TfLiteRegistration DelegateRegistration() {
1368     TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
1369     return reg;
1370   }
1371 
1372   std::unique_ptr<Interpreter> interpreter_;
1373   TfLiteDelegate delegate_;
1374 };
1375 
TEST_F(TestDelegateWithDynamicTensors, DisallowDynamicTensors) {
  // delegate_ keeps the flags from SetUp (kTfLiteDelegateFlagsNone), i.e. it
  // does not opt in to dynamic tensors.
  interpreter_->ModifyGraphWithDelegate(&delegate_);

  // Because the graph has a dynamic tensor, the delegate's Prepare must not be
  // called: the plan still holds only the original node 0.
  const auto& plan = interpreter_->execution_plan();
  ASSERT_EQ(plan.size(), 1);
  ASSERT_EQ(plan[0], 0);
}
1384 
TEST_F(TestDelegateWithDynamicTensors, AllowDynamicTensors) {
  delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors;
  // With dynamic tensors allowed, delegation must succeed outright.
  ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(&delegate_), kTfLiteOk);

  ASSERT_EQ(interpreter_->execution_plan().size(), 1);
  // The node should be replaced because dynamic tensors are allowed. Therefore
  // the only node ID in the execution plan changes from 0 to 1.
  ASSERT_EQ(interpreter_->execution_plan()[0], 1);
}
1394 
TEST_F(TestDelegateWithDynamicTensors, ModifyGraphAfterAllocate) {
  // Allocate up front so the delegate is applied to an already-allocated
  // graph.
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors;
  ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(&delegate_), kTfLiteOk);
  {
    const auto& plan = interpreter_->execution_plan();
    ASSERT_EQ(plan.size(), 1);
    ASSERT_EQ(plan[0], 1);
  }

  // A second allocation on the delegated graph must succeed as well.
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
}
1407 
TEST(TestDelegateOwnership, ProperlyDisposed) {
  // Delegate that reports, through caller-owned flags, when its Prepare has
  // run and when it has been destroyed.
  struct TfLiteInterpreterOwnedDelegate : public TfLiteDelegate {
    // The C base is value-initialized explicitly: with a user-provided
    // constructor it would otherwise be default-initialized, leaving every
    // TfLiteDelegate field not assigned below indeterminate while the
    // interpreter holds this delegate.
    TfLiteInterpreterOwnedDelegate(bool* destroyed, bool* prepared)
        : TfLiteDelegate(), destroyed(destroyed), prepared(prepared) {
      flags = kTfLiteDelegateFlagsNone;
      Prepare = [](TfLiteContext*, TfLiteDelegate* delegate) -> TfLiteStatus {
        *static_cast<TfLiteInterpreterOwnedDelegate*>(delegate)->prepared =
            true;
        return kTfLiteOk;
      };
    }
    ~TfLiteInterpreterOwnedDelegate() { *destroyed = true; }

    bool* destroyed;
    bool* prepared;
  };

  // Construct a delegate with flags for indicating preparation/destruction.
  bool destroyed = false;
  bool prepared = false;
  std::unique_ptr<TfLiteInterpreterOwnedDelegate> delegate(
      new TfLiteInterpreterOwnedDelegate(&destroyed, &prepared));
  {
    // Create an interpreter and assemble a simple graph.
    Interpreter interpreter;
    TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
    ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
    ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
    ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk);
    ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
                                                &registration),
              kTfLiteOk);

    // Pass delegate ownership to that interpreter.
    ASSERT_EQ(InterpreterTest::ModifyGraphWithDelegate(&interpreter,
                                                       std::move(delegate)),
              kTfLiteOk);

    // The delegate should be prepared as normal, and should be preserved.
    EXPECT_TRUE(prepared);
    EXPECT_FALSE(destroyed);

    // Interpreter interaction should not impact the delegate's validity. The
    // returned statuses are not checked: this test only verifies that these
    // calls leave the delegate alive.
    interpreter.AllocateTensors();
    interpreter.Invoke();
    EXPECT_FALSE(destroyed);
  }

  // Only after the interpreter is destroyed should the delegate be destroyed.
  EXPECT_TRUE(destroyed);
}
1459 
// State polled by the interpreter's cancellation callback to decide whether a
// call to Invoke() should stop.
struct CancellationData {
  bool is_cancelled = false;
};

// Cancellation callback registered via Interpreter::SetCancellationFunction.
// `data` must point to a CancellationData; returns its `is_cancelled` flag.
bool CheckCancellation(void* data) {
  const auto* state = static_cast<const CancellationData*>(data);
  return state->is_cancelled;
}

// Cancellation state shared by the CancellationTest fixture and its ops; it is
// cleared in CancellationTest::SetUp().
static CancellationData cancellation_data_;
1474 
1475 // Test fixture to test cancellation within the Interpreter.
1476 class CancellationTest : public ::testing::Test {
1477  public:
Invoke()1478   TfLiteStatus Invoke() { return interpreter_.Invoke(); }
Cancel()1479   void Cancel() { cancellation_data_.is_cancelled = true; }
1480 
1481   // Adds an CancelOp with input tensor `input` and output tensor `output`.
MakeCancelNode(int input,int output)1482   void MakeCancelNode(int input, int output) {
1483     TfLiteRegistration op = CancelOpRegistration();
1484     ASSERT_EQ(interpreter_.AddNodeWithParameters({input}, {output}, nullptr, 0,
1485                                                  nullptr, &op),
1486               kTfLiteOk);
1487     ASSERT_EQ(interpreter_.ResizeInputTensor(input, {3}), kTfLiteOk);
1488   }
1489 
1490   // Adds an OkOp with input tensor `input` and output tensor `output`.
MakeOkNode(int input,int output)1491   void MakeOkNode(int input, int output) {
1492     TfLiteRegistration op = OkOpRegistration();
1493     ASSERT_EQ(interpreter_.AddNodeWithParameters({input}, {output}, nullptr, 0,
1494                                                  nullptr, &op),
1495               kTfLiteOk);
1496     ASSERT_EQ(interpreter_.ResizeInputTensor(input, {3}), kTfLiteOk);
1497   }
1498 
1499   Interpreter interpreter_;
1500 
1501  private:
1502   // Build the kernel registration for an op that cancels the operation.
CancelOpRegistration()1503   TfLiteRegistration CancelOpRegistration() {
1504     TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
1505 
1506     // Set output size to the input size in CancelOp::Prepare(). Code exists to
1507     // have a framework in Prepare. The input and output tensors are not used.
1508     reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
1509       TfLiteTensor* in_tensor = &context->tensors[node->inputs->data[0]];
1510       TfLiteTensor* out_tensor = &context->tensors[node->outputs->data[0]];
1511       TfLiteIntArray* new_size = TfLiteIntArrayCopy(in_tensor->dims);
1512       return context->ResizeTensor(context, out_tensor, new_size);
1513     };
1514 
1515     reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
1516       cancellation_data_.is_cancelled = true;
1517       return kTfLiteOk;
1518     };
1519     return reg;
1520   }
1521 
1522   // Build the kernel registration for an op that returns kTfLiteOk.
OkOpRegistration()1523   TfLiteRegistration OkOpRegistration() {
1524     TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
1525 
1526     // Set output size to the input size in OkOp::Prepare(). Code exists to have
1527     // a framework in Prepare. The input and output tensors are not used.
1528     reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
1529       TfLiteTensor* in_tensor = &context->tensors[node->inputs->data[0]];
1530       TfLiteTensor* out_tensor = &context->tensors[node->outputs->data[0]];
1531       TfLiteIntArray* new_size = TfLiteIntArrayCopy(in_tensor->dims);
1532       return context->ResizeTensor(context, out_tensor, new_size);
1533     };
1534 
1535     reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
1536       return kTfLiteOk;
1537     };
1538     return reg;
1539   }
1540 
SetUp()1541   void SetUp() final {
1542     cancellation_data_.is_cancelled = false;
1543 
1544     // Set up the interpreter. Create the input and output tensors.
1545     int num_tensors = 3;
1546     ASSERT_EQ(interpreter_.AddTensors(num_tensors), kTfLiteOk);
1547     interpreter_.SetInputs({0});
1548     interpreter_.SetOutputs({2});
1549     TfLiteQuantizationParams quantized;
1550     for (int tensor_index = 0; tensor_index < num_tensors; tensor_index++) {
1551       ASSERT_EQ(interpreter_.SetTensorParametersReadWrite(
1552                     tensor_index, kTfLiteFloat32, "", {3}, quantized),
1553                 kTfLiteOk);
1554     }
1555     interpreter_.SetCancellationFunction(&cancellation_data_,
1556                                          &CheckCancellation);
1557   }
1558 };
1559 
TEST_F(CancellationTest, CancelBeforeInvoke) {
  // A single OkOp graph; cancellation is requested before Invoke() runs.
  MakeOkNode(1, 2);
  ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);

  Cancel();
  ASSERT_EQ(Invoke(), kTfLiteError);
}
1569 
TEST_F(CancellationTest, CancelDuringInvoke) {
  // Checks that cancellation is honoured between ops. The first node (CancelOp)
  // raises the cancellation flag while it runs. The second node (OkOp) would
  // return `kTfLiteOk` if it ever executed, so Invoke() must report an error.
  MakeCancelNode(0, 1);
  MakeOkNode(1, 2);
  ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);

  ASSERT_EQ(Invoke(), kTfLiteError);
}
1583 
1584 }  // namespace
1585 }  // namespace tflite
1586 
main(int argc,char ** argv)1587 int main(int argc, char** argv) {
1588   ::tflite::LogToStderr();
1589   ::testing::InitGoogleTest(&argc, argv);
1590   return RUN_ALL_TESTS();
1591 }
1592