1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/interpreter.h"
17 #include <gmock/gmock.h>
18 #include <gtest/gtest.h>
19 #include "tensorflow/lite/core/api/error_reporter.h"
20 #include "tensorflow/lite/kernels/internal/compatibility.h"
21 #include "tensorflow/lite/kernels/kernel_util.h"
22 #include "tensorflow/lite/schema/schema_generated.h"
23 #include "tensorflow/lite/string_util.h"
24 #include "tensorflow/lite/testing/util.h"
25
26 namespace tflite {
27
28 // InterpreterTest is a friend of Interpreter, so it can access context_.
29 class InterpreterTest : public ::testing::Test {
30 public:
31 template <typename Delegate>
32 static TfLiteStatus ModifyGraphWithDelegate(
33 Interpreter* interpreter, std::unique_ptr<Delegate> delegate) {
34 Interpreter::TfLiteDelegatePtr tflite_delegate(
35 delegate.release(), [](TfLiteDelegate* delegate) {
36 delete reinterpret_cast<Delegate*>(delegate);
37 });
38 return interpreter->ModifyGraphWithDelegate(std::move(tflite_delegate));
39 }
40
41 protected:
42 TfLiteContext* GetInterpreterContext() { return interpreter_.context_; }
43
44 Interpreter interpreter_;
45 };
46
47 namespace ops {
48 namespace builtin {
49 TfLiteRegistration* Register_PADV2();
50 TfLiteRegistration* Register_NEG();
51 } // namespace builtin
52 } // namespace ops
53 namespace {
54
55 using ::testing::IsEmpty;
56
57 // Make an interpreter that has no tensors and no nodes
58 TEST(BasicInterpreter, ZeroInterpreter) {
59 testing::internal::CaptureStderr();
60
61 Interpreter interpreter;
62 EXPECT_THAT(testing::internal::GetCapturedStderr(),
63 testing::HasSubstr("INFO: Initialized TensorFlow Lite runtime"));
64
65 interpreter.SetInputs({});
66 interpreter.SetOutputs({});
67 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
68 ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
69
70 // Creating a new interpreter should not redundantly log runtime init.
71 testing::internal::CaptureStderr();
72 Interpreter interpreter2;
73 EXPECT_THAT(testing::internal::GetCapturedStderr(), IsEmpty());
74 }
75
76 // Test various error conditions.
77 TEST(BasicInterpreter, InvokeInvalidModel) {
78 Interpreter interpreter;
79 ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
80 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
81 ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
82 }
83
84 TEST(BasicInterpreter, TestAllocateTensorsResetVariableTensors) {
85 Interpreter interpreter;
86 int tensor_index;
87 ASSERT_EQ(interpreter.AddTensors(1, &tensor_index), kTfLiteOk);
88 constexpr int kTensorSize = 16;
89 TfLiteQuantizationParams quant;
90 interpreter.SetTensorParametersReadWrite(tensor_index, kTfLiteFloat32, "",
91 {kTensorSize}, quant, true);
92 interpreter.SetVariables({tensor_index});
93 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
94 TfLiteTensor* tensor = interpreter.tensor(tensor_index);
95 // Ensure that variable tensors are reset to zero.
96 for (int i = 0; i < kTensorSize; ++i) {
97 ASSERT_EQ(tensor->data.f[i], 0.0f);
98 }
99 }
100
101 // Test size accessor functions.
102 TEST(BasicInterpreter, TestSizeFunctions) {
103 Interpreter interpreter;
104 int base_index;
105 ASSERT_EQ(interpreter.nodes_size(), 0);
106 ASSERT_EQ(interpreter.tensors_size(), 0);
107 ASSERT_EQ(interpreter.AddTensors(2, &base_index), kTfLiteOk);
108 ASSERT_EQ(interpreter.tensors_size(), 2);
109 ASSERT_EQ(base_index, 0);
110 ASSERT_EQ(interpreter.AddTensors(3, &base_index), kTfLiteOk);
111 ASSERT_EQ(interpreter.tensors_size(), 5);
112 ASSERT_EQ(base_index, 2);
113 ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
114 ASSERT_EQ(interpreter.tensors_size(), 6);
115 }
116
117 // Test if invalid indices make a model inconsistent (and conversely if
118 // valid indices keep a model consistent).
119 TEST(BasicInterpreter, InconsistentModel) {
120 // Invalid inputs
121 {
122 Interpreter interpreter;
123 ASSERT_NE(interpreter.SetInputs({5}), kTfLiteOk);
124 ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk);
125 ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
126 ASSERT_EQ(interpreter.inputs(), std::vector<int>());
127 }
128 // Invalid outputs
129 {
130 Interpreter interpreter;
131 ASSERT_NE(interpreter.SetOutputs({5}), kTfLiteOk);
132 ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk);
133 ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
134 ASSERT_EQ(interpreter.outputs(), std::vector<int>());
135 }
136 // Invalid node inputs
137 {
138 Interpreter interpreter;
139 TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
140 ASSERT_NE(interpreter.AddNodeWithParameters({3}, {0}, nullptr, 0, nullptr,
141 &registration),
142 kTfLiteOk);
143 ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk);
144 ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
145 }
146 // Valid inputs and outputs and a node with valid inputs and outputs
147 {
148 Interpreter interpreter;
149 ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
150 TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
151 ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
152 ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
153 ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
154 &registration),
155 kTfLiteOk);
156 }
157 }
158
159 // Make an interpreter that has one tensor but no ops
160 TEST(BasicInterpreter, CheckAllocate) {
161 struct {
162 TfLiteType type;
163 size_t size;
164 } cases[] = {
165 {kTfLiteFloat32, sizeof(float)}, {kTfLiteInt32, sizeof(int32_t)},
166 {kTfLiteUInt8, sizeof(uint8_t)}, {kTfLiteInt64, sizeof(int64_t)},
167 {kTfLiteInt16, sizeof(int16_t)},
168 };
169
170 for (auto test : cases) {
171 Interpreter interpreter;
172 ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
173 interpreter.SetInputs({0, 1});
174 interpreter.SetOutputs({});
175 TfLiteQuantizationParams quant;
176
177 interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant);
178 interpreter.SetTensorParametersReadWrite(1, test.type, "", {4}, quant);
179 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
180 ASSERT_EQ(interpreter.tensor(0)->bytes, 3 * test.size);
181 ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr);
182 ASSERT_EQ(interpreter.tensor(1)->bytes, 4 * test.size);
183 ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr);
184 }
185 }
186
187 TEST(BasicInterpreter, CheckQuantization) {
188 Interpreter interpreter;
189 ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
190 interpreter.SetInputs({0, 1});
191 interpreter.SetOutputs({});
192 TfLiteType tensor_type = kTfLiteInt8;
193 const uint8_t int8s[] = {3, 4};
194 float scale = 0.5f;
195 int32_t zero_point = 12;
196
197 TfLiteQuantization rw_quantization;
198 rw_quantization.type = kTfLiteAffineQuantization;
199 auto* rw_affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
200 malloc(sizeof(TfLiteAffineQuantization)));
201 rw_affine_quantization->scale = TfLiteFloatArrayCreate(1);
202 rw_affine_quantization->zero_point = TfLiteIntArrayCreate(1);
203 rw_affine_quantization->scale->data[0] = scale;
204 rw_affine_quantization->zero_point->data[0] = zero_point;
205 rw_quantization.params = rw_affine_quantization;
206
207 TfLiteQuantization ro_quantization;
208 ro_quantization.type = kTfLiteAffineQuantization;
209 auto* ro_affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
210 malloc(sizeof(TfLiteAffineQuantization)));
211 ro_affine_quantization->scale = TfLiteFloatArrayCreate(1);
212 ro_affine_quantization->zero_point = TfLiteIntArrayCreate(1);
213 ro_affine_quantization->scale->data[0] = scale;
214 ro_affine_quantization->zero_point->data[0] = zero_point;
215 ro_quantization.params = ro_affine_quantization;
216
217 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, tensor_type, "", {3},
218 rw_quantization),
219 kTfLiteOk);
220 ASSERT_EQ(interpreter.SetTensorParametersReadOnly(
221 1, tensor_type, "", {2}, ro_quantization,
222 reinterpret_cast<const char*>(int8s), 2),
223 kTfLiteOk);
224 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
225 // Check that the legacy scale and zero_point are set correctly.
226 ASSERT_EQ(interpreter.tensor(0)->params.scale, scale);
227 ASSERT_EQ(interpreter.tensor(0)->params.zero_point, zero_point);
228 ASSERT_EQ(interpreter.tensor(0)->quantization.type, rw_quantization.type);
229 ASSERT_EQ(interpreter.tensor(1)->params.scale, scale);
230 ASSERT_EQ(interpreter.tensor(1)->params.zero_point, zero_point);
231 ASSERT_EQ(interpreter.tensor(1)->quantization.type, ro_quantization.type);
232 }
233
234 TEST(BasicInterpreter, CheckResize) {
235 const float floats[] = {-3., -4.};
236 const int32_t int32s[] = {-3, -4};
237 const uint8_t uint8s[] = {3, 4};
238 const int64_t int64s[] = {6, -7};
239 const int16_t int16s[] = {8, -9};
240
241 struct {
242 TfLiteType type;
243 size_t size;
244 const char* array;
245 } cases[] = {
246 {kTfLiteFloat32, sizeof(float), reinterpret_cast<const char*>(floats)},
247 {kTfLiteInt32, sizeof(int32_t), reinterpret_cast<const char*>(int32s)},
248 {kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast<const char*>(uint8s)},
249 {kTfLiteInt64, sizeof(int64_t), reinterpret_cast<const char*>(int64s)},
250 {kTfLiteInt16, sizeof(int16_t), reinterpret_cast<const char*>(int16s)},
251 };
252
253 for (auto test : cases) {
254 Interpreter interpreter;
255
256 ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
257 interpreter.SetInputs({0, 1});
258 interpreter.SetOutputs({});
259 TfLiteQuantizationParams quant;
260
261 ASSERT_EQ(
262 interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant),
263 kTfLiteOk);
264 ASSERT_EQ(interpreter.SetTensorParametersReadOnly(
265 1, test.type, "", {2}, quant, test.array, 2 * test.size),
266 kTfLiteOk);
267 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
268 ASSERT_EQ(interpreter.ResizeInputTensor(0, {1, 2}), kTfLiteOk);
269 // Resizing a mmapped tensor is not allowed and should produce an error.
270 ASSERT_NE(interpreter.ResizeInputTensor(1, {3}), kTfLiteOk);
271 // Set the tensor to be mmapped but with a buffer size that is insufficient
272 // to match the dimensionality.
273 ASSERT_NE(interpreter.SetTensorParametersReadOnly(
274 1, test.type, "", {2}, quant, test.array, 1 * test.size),
275 kTfLiteOk);
276 // Allocating should work since we should have our last correct array
277 // values in place.
278 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
279 }
280 }
281
282 TEST(BasicInterpreter, CheckAlignment) {
283 struct {
284 TfLiteType type;
285 } cases[] = {
286 {kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt8},
287 {kTfLiteInt64}, {kTfLiteInt16},
288 };
289
290 for (auto test : cases) {
291 Interpreter interpreter;
292
293 ASSERT_EQ(interpreter.AddTensors(4), kTfLiteOk);
294
295 for (int i = 0; i < 4; i++) {
296 TfLiteQuantizationParams quant;
297 interpreter.SetTensorParametersReadWrite(i, test.type, "", {2 * i + 1},
298 quant);
299 }
300 interpreter.AllocateTensors();
301 for (int i = 0; i < 4; i++) {
302 const TfLiteTensor& tensor = *interpreter.tensor(i);
303 ASSERT_EQ(reinterpret_cast<intptr_t>(tensor.data.raw) % 4, 0);
304 }
305 }
306 }
307
308 TEST(BasicInterpreter, CheckArenaAllocation) {
309 Interpreter interpreter;
310 ASSERT_EQ(interpreter.AddTensors(10), kTfLiteOk);
311
312 TfLiteQuantizationParams quant;
313 TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
314
315 std::vector<int> sizes{2048, 4096, 1023, 2047, 1021,
316 2047, 1023, 2046, 0, 2048};
317 for (int i = 0; i < sizes.size(); ++i) {
318 interpreter.SetTensorParametersReadWrite(i, kTfLiteUInt8, "", {sizes[i]},
319 quant);
320 }
321 interpreter.SetInputs({0, 1});
322 interpreter.SetOutputs({9, 4});
323 interpreter.AddNodeWithParameters({0, 1}, {2, 3}, nullptr, 0, nullptr, &reg);
324 interpreter.AddNodeWithParameters({2, 1}, {4, 5}, nullptr, 0, nullptr, &reg);
325 interpreter.AddNodeWithParameters({4, 3}, {6, 7}, nullptr, 0, nullptr, &reg);
326 interpreter.AddNodeWithParameters({6, 5}, {8}, nullptr, 0, nullptr, &reg);
327 interpreter.AddNodeWithParameters({8, 7}, {9}, nullptr, 0, nullptr, &reg);
328
329 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
330
331 ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(1)->data.raw);
332 ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(2)->data.raw);
333 ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(3)->data.raw);
334 ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(4)->data.raw);
335 ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(5)->data.raw);
336 ASSERT_LT(interpreter.tensor(5)->data.raw, interpreter.tensor(7)->data.raw);
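// Tensor 2's lifetime ends once node 1 has consumed it, and tensor 6 (same
// 1023-byte size) is produced later by node 2, so the arena planner is
// expected to reuse tensor 2's slot for tensor 6 (hence the pointer equality
// below). This explanation assumes the default lifetime-based arena planner.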
337 ASSERT_EQ(interpreter.tensor(6)->data.raw, interpreter.tensor(2)->data.raw);
338 // #7 is the one with the largest pointer.
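// Tensor 8 was declared with zero elements (sizes[8] == 0), so no arena
// buffer should be assigned to it.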
339 ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr);
340 ASSERT_EQ(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw);
341 }
342
343 TEST(BasicInterpreter, BufferAccess) {
344 Interpreter interpreter;
345 ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
346 ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
347
348 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
349 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
350 kTfLiteOk);
351 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
352 // Verify that we get a valid pointer.
353 ASSERT_NE(interpreter.typed_tensor<float>(0), nullptr);
354 // Verify that a pointer of the wrong type is not returned.
355 ASSERT_EQ(interpreter.typed_tensor<int>(0), nullptr);
356 // Verify that raw c interface ptr matches safe interface.
357 ASSERT_EQ(interpreter.typed_tensor<float>(0), interpreter.tensor(0)->data.f);
358 }
359
360 TEST(BasicInterpreter, NoOpInterpreter) {
361 Interpreter interpreter;
362 ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
363 ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
364 ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
365
366 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
367 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
368 kTfLiteOk);
369
370 ASSERT_EQ(interpreter.ResizeInputTensor(interpreter.inputs()[0], {1, 2, 3}),
371 kTfLiteOk);
372 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
373 ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
374 }
375
376 TEST(BasicInterpreter, RedundantAllocateTensors) {
377 Interpreter interpreter;
378 ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
379 ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
380
381 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
382 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
383 kTfLiteOk);
384
385 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
386 const auto data_raw = interpreter.tensor(0)->data.raw;
387 ASSERT_NE(data_raw, nullptr);
388
389 // A redundant allocation request should have no impact.
390 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
391 ASSERT_EQ(interpreter.tensor(0)->data.raw, data_raw);
392 }
393
394 TEST(BasicInterpreter, RedundantAllocateTensorsWithDynamicInputs) {
395 Interpreter interpreter;
396 TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
397 ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
398 interpreter.SetInputs({0});
399 interpreter.SetOutputs({1});
400 interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg);
401
402 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
403 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
404 kTfLiteOk);
405 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
406 1, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
407 kTfLiteOk);
408
409 // Configure the input tensor as dynamic.
410 interpreter.tensor(0)->data.raw = nullptr;
411 interpreter.tensor(0)->allocation_type = kTfLiteDynamic;
412
413 ASSERT_EQ(interpreter.ResizeInputTensor(interpreter.inputs()[0], {1, 2, 3}),
414 kTfLiteOk);
415 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
416 ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr);
417
418 // Reset the output tensor's buffer.
419 interpreter.tensor(1)->data.raw = nullptr;
420
421 // A redundant allocation request should be honored, as the input tensor
422 // was marked dynamic.
423 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
424 ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr);
425 }
426
427 TEST(BasicInterpreter, ResizingTensors) {
428 Interpreter interpreter;
429 ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
430 ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
431 ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
432
433 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
434 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
435 kTfLiteOk);
436
437 int t = interpreter.inputs()[0];
438 TfLiteTensor* tensor = interpreter.tensor(t);
439
440 ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk);
441 EXPECT_EQ(tensor->bytes, 6 * sizeof(float));
442 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
443
444 tensor->data.f[5] = 0.123f;
445
446 // Changing from kTfLiteArenaRw to kTfLiteDynamic is quite complicated: we need
447 // to unset data.raw, otherwise Realloc will try to free that memory.
448 tensor->data.raw = nullptr;
449 tensor->allocation_type = kTfLiteDynamic;
450
451 ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 4}), kTfLiteOk);
452 EXPECT_EQ(tensor->bytes, 8 * sizeof(float));
453 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
454
455 ASSERT_EQ(interpreter.ResizeInputTensor(t, {}), kTfLiteOk);
456 EXPECT_EQ(tensor->bytes, 1 * sizeof(float));
457 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
458
459 ASSERT_EQ(interpreter.ResizeInputTensor(t, {0}), kTfLiteOk);
460 EXPECT_EQ(tensor->bytes, 0);
461 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
462
463 ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 0}), kTfLiteOk);
464 EXPECT_EQ(tensor->bytes, 0);
465 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
466
467 // TODO(ahentz): We shouldn't have to force reallocation, but
468 // ResizeInputTensor doesn't realloc dynamic tensors. Also note that
469 // TfLiteTensorRealloc(tensor->bytes, tensor) is a no-op.
470 TfLiteTensorRealloc(9 * sizeof(float), tensor);
471 tensor->data.f[7] = 0.123f;
472
473 ASSERT_EQ(interpreter.ResizeInputTensor(t, {2, 2, 4}), kTfLiteOk);
474 EXPECT_EQ(tensor->bytes, 16 * sizeof(float));
475 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
476
477 // TODO(ahentz): We shouldn't have to force reallocation, but
478 // ResizeInputTensor doesn't realloc dynamic tensors. Also note that
479 // TfLiteTensorRealloc(tensor->bytes, tensor) is a no-op.
480 TfLiteTensorRealloc(17 * sizeof(float), tensor);
481 tensor->data.f[15] = 0.123f;
482 }
483
484 TEST(BasicInterpreter, NoopResizingTensors) {
485 Interpreter interpreter;
486 ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
487 ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
488 ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
489
490 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
491 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
492 kTfLiteOk);
493
494 int t = interpreter.inputs()[0];
495 TfLiteTensor* tensor = interpreter.tensor(t);
496
497 ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk);
498 EXPECT_EQ(tensor->bytes, 6 * sizeof(float));
499 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
500 tensor->data.f[5] = 0.123f;
501
502 // Resizing to the same size should not trigger re-allocation.
503 ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk);
504 EXPECT_EQ(tensor->bytes, 6 * sizeof(float));
505 ASSERT_NE(tensor->data.raw, nullptr);
506 ASSERT_EQ(tensor->data.f[5], 0.123f);
507
508 // Explicitly allocating should be a no-op, as no resize was performed.
509 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
510 EXPECT_EQ(tensor->bytes, 6 * sizeof(float));
511 ASSERT_NE(tensor->data.raw, nullptr);
512 ASSERT_EQ(tensor->data.f[5], 0.123f);
513 }
514
515 TEST(BasicInterpreter, OneOpInterpreter) {
516 Interpreter interpreter;
517 ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
518 ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
519 ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk);
520
521 TfLiteQuantizationParams quantized;
522 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "in1",
523 {3}, quantized),
524 kTfLiteOk);
525 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "out0",
526 {3}, quantized),
527 kTfLiteOk);
528
529 ASSERT_EQ(interpreter.GetInputName(0), "in1");
530 ASSERT_EQ(interpreter.GetOutputName(0), "out0");
531
532 TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
533 reg.init = [](TfLiteContext* context, const char*, size_t) -> void* {
534 auto* first_new_tensor = new int;
535 context->AddTensors(context, 2, first_new_tensor);
536 return first_new_tensor;
537 };
538 reg.free = [](TfLiteContext* context, void* buffer) {
539 delete reinterpret_cast<int*>(buffer);
540 };
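// The prepare function below resizes the output to the input's shape and
// registers the two tensors created in init as node temporaries, both
// allocated in the arena (kTfLiteArenaRw).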
541 reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
542 auto* first_new_tensor = reinterpret_cast<int*>(node->user_data);
543
544 TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]];
545 TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]];
546
547 TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
548 TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor1, newSize));
549
550 TfLiteIntArrayFree(node->temporaries);
551 node->temporaries = TfLiteIntArrayCreate(2);
552 for (int i = 0; i < 2; ++i) {
553 node->temporaries->data[i] = *(first_new_tensor) + i;
554 }
555
556 auto setup_temporary = [&](int id) {
557 TfLiteTensor* tmp = &context->tensors[id];
558 tmp->type = kTfLiteFloat32;
559 tmp->allocation_type = kTfLiteArenaRw;
560 return context->ResizeTensor(context, tmp,
561 TfLiteIntArrayCopy(tensor0->dims));
562 };
563 TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[0]));
564 TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[1]));
565
566 return kTfLiteOk;
567 };
568 reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
569 TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
570
571 auto populate = [&](int id) {
572 TfLiteTensor* t = &context->tensors[id];
573 int num = a0->dims->data[0];
574 for (int i = 0; i < num; i++) {
575 t->data.f[i] = a0->data.f[i];
576 }
577 };
578
579 populate(node->outputs->data[0]);
580 populate(node->temporaries->data[0]);
581 populate(node->temporaries->data[1]);
582 return kTfLiteOk;
583 };
584 ASSERT_EQ(
585 interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg),
586 kTfLiteOk);
587 ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk);
588 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
589
590 ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
591 }
592
593 // Forcefully divides tensor allocation into three steps: one before invocation
594 // and two more at invocation time. This happens because we use string tensors
595 // and their sizes can't be determined until invocation time.
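// The graph built below is: tensor1 = copy(tensor0), tensor2 = len(tensor1),
// tensor3 = copy(tensor0), tensor4 = len(tensor3), where copy and len are the
// custom string ops registered further down.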
596 TEST(BasicInterpreter, ThreeStepAllocate) {
597 Interpreter interpreter;
598 ASSERT_EQ(interpreter.AddTensors(5), kTfLiteOk);
599 ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
600 ASSERT_EQ(interpreter.SetOutputs({4}), kTfLiteOk);
601
602 TfLiteQuantizationParams quantized;
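// The raw buffer below follows the TfLite string tensor serialization: an
// int32 string count (1), an int32 offset of the first string (12), an int32
// total buffer length (15), then the string bytes "ABC", for 15 bytes in all.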
603 char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'A', 'B', 'C'};
604 // Read-only string tensor.
605 ASSERT_EQ(interpreter.SetTensorParametersReadOnly(0, kTfLiteString, "", {1},
606 quantized, data, 15),
607 kTfLiteOk);
608 // Read-write string tensor.
609 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteString, "", {1},
610 quantized),
611 kTfLiteOk);
612 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(2, kTfLiteInt32, "", {1},
613 quantized),
614 kTfLiteOk);
615 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(3, kTfLiteString, "", {1},
616 quantized),
617 kTfLiteOk);
618 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(4, kTfLiteInt32, "", {1},
619 quantized),
620 kTfLiteOk);
621
622 // String-in String-out node.
623 TfLiteRegistration reg_copy = {nullptr, nullptr, nullptr, nullptr};
624 reg_copy.invoke = [](TfLiteContext* context, TfLiteNode* node) {
625 TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
626 TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
627 DynamicBuffer buf;
628 StringRef str_ref = GetString(input, 0);
629 buf.AddString(str_ref);
630 buf.WriteToTensorAsVector(output);
631 return kTfLiteOk;
632 };
633
634 // String-in Int-out node.
635 TfLiteRegistration reg_len = {nullptr, nullptr, nullptr, nullptr};
636 reg_len.prepare = [](TfLiteContext* context, TfLiteNode* node) {
637 TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
638 TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
639 outputSize->data[0] = 1;
640 return context->ResizeTensor(context, output, outputSize);
641 };
642 reg_len.invoke = [](TfLiteContext* context, TfLiteNode* node) {
643 TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
644 TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]];
645 a1->data.i32[0] = a0->bytes;
646 return kTfLiteOk;
647 };
648
649 ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
650 &reg_copy),
651 kTfLiteOk);
652 ASSERT_EQ(interpreter.AddNodeWithParameters({1}, {2}, nullptr, 0, nullptr,
653 &reg_len),
654 kTfLiteOk);
655 ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {3}, nullptr, 0, nullptr,
656 &reg_copy),
657 kTfLiteOk);
658 ASSERT_EQ(interpreter.AddNodeWithParameters({3}, {4}, nullptr, 0, nullptr,
659 &reg_len),
660 kTfLiteOk);
661
662 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
663 ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
664
665 ASSERT_EQ(interpreter.tensor(0)->bytes, 15);
666 ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr);
667 ASSERT_EQ(interpreter.tensor(1)->bytes, 15);
668 ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr);
669 ASSERT_EQ(interpreter.tensor(3)->bytes, 15);
670 ASSERT_NE(interpreter.tensor(4)->data.raw, nullptr);
671 ASSERT_EQ(interpreter.tensor(2)->bytes, 4);
672 ASSERT_EQ(interpreter.tensor(2)->data.i32[0], 15);
673 ASSERT_EQ(interpreter.tensor(4)->bytes, 4);
674 ASSERT_EQ(interpreter.tensor(4)->data.i32[0], 15);
675 }
676
677 TEST(BasicInterpreter, AllocateTwice) {
678 Interpreter interpreter;
679 ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
680 ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
681 ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk);
682
683 TfLiteQuantizationParams quantized;
684 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
685 quantized),
686 kTfLiteOk);
687 ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
688 quantized),
689 kTfLiteOk);
690
691 TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
692 reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
693 TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]];
694 TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]];
695 TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
696 return context->ResizeTensor(context, tensor1, newSize);
697 };
698 reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
699 TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
700 TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]];
701 int num = a0->dims->data[0];
702 for (int i = 0; i < num; i++) {
703 a1->data.f[i] = a0->data.f[i];
704 }
705 return kTfLiteOk;
706 };
707 ASSERT_EQ(
708 interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg),
709 kTfLiteOk);
710 ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk);
711 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
712 ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
713 char* old_tensor0_ptr = interpreter.tensor(0)->data.raw;
714 char* old_tensor1_ptr = interpreter.tensor(1)->data.raw;
715
716 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
717 ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
718 ASSERT_EQ(old_tensor0_ptr, interpreter.tensor(0)->data.raw);
719 ASSERT_EQ(old_tensor1_ptr, interpreter.tensor(1)->data.raw);
720 }
721
722 TEST(BasicInterpreter, TestNullErrorReporter) {
723 TestErrorReporter reporter;
724 Interpreter interpreter;
725 }
726
727 TEST(BasicInterpreter, TestCustomErrorReporter) {
728 TestErrorReporter reporter;
729 Interpreter interpreter(&reporter);
730 ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
731 ASSERT_EQ(reporter.error_messages(),
732 "Invoke called on model that is not ready.");
733 ASSERT_EQ(reporter.num_calls(), 1);
734 }
735
736 TEST(BasicInterpreter, TestUnsupportedDelegateFunctions) {
737 Interpreter interpreter;
738 ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
739 TfLiteRegistration registration = {
740 .init = nullptr, .free = nullptr, .prepare = nullptr, .invoke = nullptr};
741 // These functions are only supported inside a delegate's Prepare function.
742 // The test verifies that these functions return `kTfLiteError` rather than
743 // `kTfLiteOk`, and that they do not crash.
744 registration.prepare = [](TfLiteContext* context, TfLiteNode* node) {
745 {
746 TfLiteIntArray* execution_plan;
747 EXPECT_EQ(context->GetExecutionPlan(context, &execution_plan),
748 kTfLiteError);
749 }
750 {
751 TfLiteNode* node;
752 TfLiteRegistration* registration;
753 EXPECT_EQ(
754 context->GetNodeAndRegistration(context, 0, &node, &registration),
755 kTfLiteError);
756 }
757 {
758 TfLiteRegistration delegate_registration = {nullptr, nullptr, nullptr,
759 nullptr};
760 TfLiteIntArray nodes_to_replace;
761 nodes_to_replace.size = 0;
762 EXPECT_EQ(context->ReplaceNodeSubsetsWithDelegateKernels(
763 context, delegate_registration, &nodes_to_replace, nullptr),
764 kTfLiteError);
765 }
766 return kTfLiteError;
767 };
768 ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
769 ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
770 ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
771 &registration),
772 kTfLiteOk);
773 EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteError);
774 }
775
776 TEST(BasicInterpreter, DynamicTensorsResizeDescendants) {
777 // Assemble a graph with a node that has dynamically sized output (via the
778 // pad op), followed by a node with a standard element-wise op (negate).
779 Interpreter interpreter;
780 interpreter.AddTensors(4);
781 interpreter.SetInputs({0, 1});
782 interpreter.SetOutputs({3});
783 TfLiteQuantizationParams quant;
784 interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {2, 2, 1, 1},
785 quant);
786 interpreter.SetTensorParametersReadWrite(1, kTfLiteInt32, "", {4, 2}, quant);
787 interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {}, quant);
788 interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {}, quant);
789
790 TfLiteRegistration* pad_op = tflite::ops::builtin::Register_PADV2();
791 TfLiteRegistration* neg_op = tflite::ops::builtin::Register_NEG();
792 interpreter.AddNodeWithParameters({0, 1}, {2}, nullptr, 0, nullptr, pad_op);
793 interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, neg_op);
794 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
795
796 // Configure [[2,2],[4,4]] padding and execute the graph.
797 interpreter.typed_tensor<int>(1)[0] = 2;
798 interpreter.typed_tensor<int>(1)[1] = 2;
799 interpreter.typed_tensor<int>(1)[2] = 2;
800 interpreter.typed_tensor<int>(1)[3] = 2;
801 interpreter.typed_tensor<int>(1)[4] = 0;
802 interpreter.typed_tensor<int>(1)[5] = 0;
803 interpreter.typed_tensor<int>(1)[6] = 0;
804 interpreter.typed_tensor<int>(1)[7] = 0;
805 ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
806
807 // Both the output and intermediate tensor sizes should reflect the output
808 // from the dynamic pad operation.
809 ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 6 * 6);
810 ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 6 * 6);
811
812 // Now configure [[4,4],[6,6]] padding and execute the graph.
813 interpreter.typed_tensor<int>(1)[0] = 4;
814 interpreter.typed_tensor<int>(1)[1] = 4;
815 interpreter.typed_tensor<int>(1)[2] = 6;
816 interpreter.typed_tensor<int>(1)[3] = 6;
817 interpreter.typed_tensor<int>(1)[4] = 0;
818 interpreter.typed_tensor<int>(1)[5] = 0;
819 interpreter.typed_tensor<int>(1)[6] = 0;
820 interpreter.typed_tensor<int>(1)[7] = 0;
821 ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
822
823 // Again, the output and intermediate tensor sizes should reflect the *new*
824 // resize from the latest pad operation.
825 ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 10 * 14);
826 ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 10 * 14);
827 }
828
829 TEST(InterpreterTensorsCapacityTest, TestWithinHeadroom) {
830 Interpreter interpreter;
831 ASSERT_EQ(interpreter.AddTensors(Interpreter::kTensorsReservedCapacity),
832 kTfLiteOk);
833 TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
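// Adding up to kTensorsCapacityHeadroom extra tensors from within a node's
// prepare should fit in the reserved capacity, so context->tensors is not
// reallocated and previously obtained pointers stay valid.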
834 registration.prepare = [](TfLiteContext* context, TfLiteNode* node) {
835 TfLiteTensor* first_tensor = context->tensors;
836
837 int new_tensor_index;
838 context->AddTensors(context, Interpreter::kTensorsCapacityHeadroom,
839 &new_tensor_index);
840 EXPECT_EQ(first_tensor, context->tensors);
841 return kTfLiteOk;
842 };
843 ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
844 &registration),
845 kTfLiteOk);
846 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
847 }
848
849 TEST(InterpreterTensorsCapacityTest, TestExceedHeadroom) {
850 Interpreter interpreter;
851 ASSERT_EQ(interpreter.AddTensors(Interpreter::kTensorsReservedCapacity),
852 kTfLiteOk);
853 TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
854 registration.prepare = [](TfLiteContext* context, TfLiteNode* node) {
855 TfLiteTensor* first_tensor = context->tensors;
856
857 int new_tensor_index;
858 context->AddTensors(context, Interpreter::kTensorsCapacityHeadroom + 1,
859 &new_tensor_index);
860 EXPECT_NE(first_tensor, context->tensors);
861 return kTfLiteOk;
862 };
863 ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
864 &registration),
865 kTfLiteOk);
866 ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
867 }
868
869 struct TestExternalContext : public TfLiteExternalContext {
870 static const TfLiteExternalContextType kType = kTfLiteGemmLowpContext;
871
872 static TestExternalContext* Get(TfLiteContext* context) {
873 return reinterpret_cast<TestExternalContext*>(
874 context->GetExternalContext(context, kType));
875 }
876
877 static void Set(TfLiteContext* context, TestExternalContext* value) {
878 context->SetExternalContext(context, kType, value);
879 }
880
881 int num_refreshes = 0;
882 };
883
884 TEST_F(InterpreterTest, GetSetResetExternalContexts) {
885 auto* context = GetInterpreterContext();
886
887 TestExternalContext external_context;
888 external_context.Refresh = [](TfLiteContext* context) {
889 auto* ptr = TestExternalContext::Get(context);
890 if (ptr != nullptr) {
891 ++ptr->num_refreshes;
892 }
893 return kTfLiteOk;
894 };
895
896 EXPECT_EQ(TestExternalContext::Get(context), nullptr);
897 interpreter_.SetNumThreads(4);
898
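// SetNumThreads() triggers a Refresh() on every registered external context,
// so with the context registered the two calls below should bump
// num_refreshes to 2.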
899 TestExternalContext::Set(context, &external_context);
900 EXPECT_EQ(TestExternalContext::Get(context), &external_context);
901 interpreter_.SetNumThreads(4);
902 interpreter_.SetNumThreads(5);
903 EXPECT_EQ(external_context.num_refreshes, 2);
904
905 TestExternalContext::Set(context, nullptr);
906 EXPECT_EQ(TestExternalContext::Get(context), nullptr);
907 interpreter_.SetNumThreads(4);
908 }
909
910 // Test fixture that allows playing with execution plans. It creates a two
911 // node graph that can be executed in either [0,1] order or [1,0] order.
912 // The CopyOp records when it is invoked in the class member run_order_
913 // so we can test whether the execution plan was honored.
914 class TestExecutionPlan : public ::testing::Test {
915 // Encapsulates the node ids and provides them through a C primitive data
916 // type (void*). Allocatable with placement new, but never destructed, so
917 // make sure this doesn't own any heap-allocated data. This is then used as
918 // op-local data to allow access to the test fixture data.
919 class CallReporting {
920 public:
921 CallReporting(int node_id, std::vector<int>* run_order)
922 : node_id_(node_id), run_order_(run_order) {}
923
924 void Record() { run_order_->push_back(node_id_); }
925
926 private:
927 // The node id for this particular node
928 int node_id_;
929 // A pointer to the global run-order
930 std::vector<int>* run_order_;
931 };
932
933 // Build a kernel registration for an op that copies its one input
934 // to an output
935 TfLiteRegistration CopyOpRegistration() {
936 TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
937
938 reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
939 // Set output size to input size
940 TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]];
941 TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]];
942 TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
943 return context->ResizeTensor(context, tensor1, newSize);
944 };
945
946 reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
947 CallReporting* call_reporting =
948 reinterpret_cast<CallReporting*>(node->builtin_data);
949 // Copy input data to output data.
950 TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
951 TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]];
952 int num = a0->dims->data[0];
953 for (int i = 0; i < num; i++) {
954 a1->data.f[i] = a0->data.f[i];
955 }
956 call_reporting->Record();
957 return kTfLiteOk;
958 };
959 return reg;
960 }
961
962 // Adds a copy node going from tensor `input` to output tensor `output`.
963 // Note, input is used as the node_id. Inject run_order as op-accessible
964 // data. Note: this is a somewhat unusual way to do it, but it uses op
965 // functionality to avoid static global variables.
966 void MakeCopyNode(int input, int output) {
967 // Ownership of call_reporting is taken by interpreter (malloc is used due
968 // to nodes being a C99 interface so free() is used).
969 TfLiteRegistration copy_op = CopyOpRegistration();
970 CallReporting* call_reporting_1 =
971 reinterpret_cast<CallReporting*>(malloc(sizeof(CallReporting)));
972 new (call_reporting_1) CallReporting(input, &run_order_);
973 ASSERT_EQ(interpreter_.AddNodeWithParameters(
974 {input}, {output}, nullptr, 0,
975 reinterpret_cast<void*>(call_reporting_1), &copy_op),
976 kTfLiteOk);
977 ASSERT_EQ(interpreter_.ResizeInputTensor(input, {3}), kTfLiteOk);
978 }
979
980 void SetUp() final {
981 // Add two inputs and two outputs that don't depend on each other
982 ASSERT_EQ(interpreter_.AddTensors(4), kTfLiteOk);
983 interpreter_.SetInputs({0, 1});
984 interpreter_.SetOutputs({2, 3});
985 TfLiteQuantizationParams quantized;
986 for (int tensor_index = 0; tensor_index < 4; tensor_index++) {
987 ASSERT_EQ(interpreter_.SetTensorParametersReadWrite(
988 tensor_index, kTfLiteFloat32, "", {3}, quantized),
989 kTfLiteOk);
990 }
991
992 // Define two copy functions that also use the user_data to report that
993 // they were called.
994 // i.e. tensor[2] = copy(tensor[0]); tensor[3] = copy(tensor[1]);
995 // thus we can reorder the two nodes arbitrarily and still satisfy dependency
996 // order.
997 MakeCopyNode(0, 2);
998 MakeCopyNode(1, 3);
999
1000 ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
1001 }
1002
1003 protected:
1004 Interpreter interpreter_;
1005
1006 // list of node_ids that were run
1007 std::vector<int> run_order_;
1008 };
1009
1010 TEST_F(TestExecutionPlan, DefaultExecutionPlan) {
1011 // Check default order
1012 ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
1013 ASSERT_EQ(run_order_, std::vector<int>({0, 1}));
1014 }
1015
1016 TEST_F(TestExecutionPlan, ReversedExecutionPlan) {
1017 // Check reversed order
1018 interpreter_.SetExecutionPlan({1, 0});
1019 ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
1020 ASSERT_EQ(run_order_, std::vector<int>({1, 0}));
1021 }
1022
1023 TEST_F(TestExecutionPlan, SubsetExecutionPlan) {
1024 // Check running only node index 1
1025 interpreter_.SetExecutionPlan({1});
1026 ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
1027 ASSERT_EQ(run_order_, std::vector<int>({1}));
1028 }
1029
1030 TEST_F(TestExecutionPlan, NullExecutionPlan) {
1031 // Check nothing executed.
1032 interpreter_.SetExecutionPlan({});
1033 ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
1034 ASSERT_EQ(run_order_, std::vector<int>());
1035 }
1036
1037 // Build a kernel registration for a custom op that adds its two inputs
1038 // into an output.
1039 TfLiteRegistration AddOpRegistration() {
1040 TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
1041
1042 reg.custom_name = "my_add";
1043 reg.builtin_code = tflite::BuiltinOperator_CUSTOM;
1044
1045 reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
1046 // Set output size to input size
1047 TfLiteTensor* input1 = &context->tensors[node->inputs->data[0]];
1048 TfLiteTensor* input2 = &context->tensors[node->inputs->data[1]];
1049 TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
1050
1051 TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size);
1052 for (int i = 0; i < input1->dims->size; ++i) {
1053 TF_LITE_ENSURE_EQ(context, input1->dims->data[i], input2->dims->data[i]);
1054 }
1055
1056 TF_LITE_ENSURE_STATUS(context->ResizeTensor(
1057 context, output, TfLiteIntArrayCopy(input1->dims)));
1058 return kTfLiteOk;
1059 };
1060
1061 reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
1062 // Copy input data to output data.
1063 TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
1064 TfLiteTensor* a1 = &context->tensors[node->inputs->data[1]];
1065 TfLiteTensor* out = &context->tensors[node->outputs->data[0]];
1066 int num = a0->dims->data[0];
1067 for (int i = 0; i < num; i++) {
1068 out->data.f[i] = a0->data.f[i] + a1->data.f[i];
1069 }
1070 return kTfLiteOk;
1071 };
1072 return reg;
1073 }
1074
1075 class TestDelegate : public ::testing::Test {
1076 protected:
1077 void SetUp() override {
1078 interpreter_.reset(new Interpreter);
1079 interpreter_->AddTensors(5);
1080 interpreter_->SetInputs({0, 1});
1081 interpreter_->SetOutputs({3, 4});
1082 TfLiteQuantizationParams quant;
1083 interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
1084 quant);
1085 interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
1086 quant);
1087 interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3},
1088 quant);
1089 interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3},
1090 quant);
1091 interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3},
1092 quant);
1093 TfLiteRegistration reg = AddOpRegistration();
1094 interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, &reg);
1095 interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, &reg);
1096 interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, &reg);
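// The graph is: node 0: tensor2 = tensor0 + tensor0; node 1: tensor3 =
// tensor1 + tensor1; node 2: tensor4 = tensor2 + tensor1. Outputs are
// tensors 3 and 4.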
1097 }
1098
1099 void TearDown() override {
1100 // The interpreter relies on delegate_ to free its resources properly, so
1101 // the delegate must outlive the interpreter.
1102 interpreter_.reset();
1103 delegate_.reset();
1104 }
1105
1106 TfLiteBufferHandle last_allocated_handle_ = kTfLiteNullBufferHandle;
1107
1108 TfLiteBufferHandle AllocateBufferHandle() { return ++last_allocated_handle_; }
1109
1110 protected:
1111 class SimpleDelegate {
1112 public:
1113 // Create a simple implementation of a TfLiteDelegate. The C++ class
1114 // SimpleDelegate produces a TfLiteDelegate handle that is value-copyable
1115 // and compatible with TfLite.
1116 explicit SimpleDelegate(const std::vector<int>& nodes) : nodes_(nodes) {
1117 delegate_.Prepare = [](TfLiteContext* context,
1118 TfLiteDelegate* delegate) -> TfLiteStatus {
1119 auto* simple = reinterpret_cast<SimpleDelegate*>(delegate->data_);
1120 TfLiteIntArray* nodes_to_separate =
1121 TfLiteIntArrayCreate(simple->nodes_.size());
1122 // Mark nodes that we want in TfLiteIntArray* structure.
1123 int index = 0;
1124 for (auto node_index : simple->nodes_) {
1125 nodes_to_separate->data[index++] = node_index;
1126 // Make sure the node is the custom add op.
1127 TfLiteNode* node;
1128 TfLiteRegistration* reg;
1129 context->GetNodeAndRegistration(context, node_index, &node, &reg);
1130 TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
1131 TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
1132 }
1133 // Check that all nodes are available
1134 TfLiteIntArray* execution_plan;
1135 TF_LITE_ENSURE_STATUS(
1136 context->GetExecutionPlan(context, &execution_plan));
1137 for (int exec_index = 0; exec_index < execution_plan->size;
1138 exec_index++) {
1139 int node_index = execution_plan->data[exec_index];
1140 // Check that we are an identity map to start.
1141 TFLITE_CHECK_EQ(exec_index, node_index);
1142 TfLiteNode* node;
1143 TfLiteRegistration* reg;
1144 context->GetNodeAndRegistration(context, node_index, &node, &reg);
1145 TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
1146 TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
1147 }
1148
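// Hand the marked nodes over to the runtime. They are removed from the
// execution plan and replaced by fused delegate kernel node(s) that use
// FakeFusedRegistration().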
1149 context->ReplaceNodeSubsetsWithDelegateKernels(
1150 context, FakeFusedRegistration(), nodes_to_separate, delegate);
1151 TfLiteIntArrayFree(nodes_to_separate);
1152 return kTfLiteOk;
1153 };
1154 delegate_.CopyToBufferHandle = [](TfLiteContext* context,
1155 TfLiteDelegate* delegate,
1156 TfLiteBufferHandle buffer_handle,
1157 TfLiteTensor* tensor) -> TfLiteStatus {
1158 // TODO(ycling): Implement tests to test buffer copying logic.
1159 return kTfLiteOk;
1160 };
1161 delegate_.CopyFromBufferHandle =
1162 [](TfLiteContext* context, TfLiteDelegate* delegate,
1163 TfLiteBufferHandle buffer_handle,
1164 TfLiteTensor* output) -> TfLiteStatus {
1165 // TODO(ycling): Implement tests to test buffer copying logic.
1166 return kTfLiteOk;
1167 };
1168 delegate_.FreeBufferHandle =
1169 [](TfLiteContext* context, TfLiteDelegate* delegate,
1170 TfLiteBufferHandle* handle) { *handle = kTfLiteNullBufferHandle; };
1171 // Store a type-punned pointer to this SimpleDelegate instance.
1172 delegate_.data_ = reinterpret_cast<void*>(this);
1173 delegate_.flags = kTfLiteDelegateFlagsNone;
1174 }
1175
1176 static TfLiteRegistration FakeFusedRegistration() {
1177 TfLiteRegistration reg = {nullptr};
1178 reg.custom_name = "fake_fused_op";
1179 return reg;
1180 }
1181
1182 TfLiteDelegate* get_tf_lite_delegate() { return &delegate_; }
1183
1184 private:
1185 std::vector<int> nodes_;
1186 TfLiteDelegate delegate_;
1187 };
1188 std::unique_ptr<Interpreter> interpreter_;
1189 std::unique_ptr<SimpleDelegate> delegate_;
1190 };
1191
1192 TEST_F(TestDelegate, BasicDelegate) {
1193 delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
1194 interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
1195
1196 ASSERT_EQ(interpreter_->execution_plan().size(), 1);
1197 int node = interpreter_->execution_plan()[0];
1198 const auto* node_and_reg = interpreter_->node_and_registration(node);
1199 EXPECT_EQ(node_and_reg->second.custom_name,
1200 SimpleDelegate::FakeFusedRegistration().custom_name);
1201
1202 const TfLiteDelegateParams* params =
1203 reinterpret_cast<const TfLiteDelegateParams*>(
1204 node_and_reg->first.builtin_data);
1205 ASSERT_EQ(params->nodes_to_replace->size, 3);
1206 EXPECT_EQ(params->nodes_to_replace->data[0], 0);
1207 EXPECT_EQ(params->nodes_to_replace->data[1], 1);
1208 EXPECT_EQ(params->nodes_to_replace->data[2], 2);
1209
1210 ASSERT_EQ(params->input_tensors->size, 2);
1211 EXPECT_EQ(params->input_tensors->data[0], 0);
1212 EXPECT_EQ(params->input_tensors->data[1], 1);
1213
1214 ASSERT_EQ(params->output_tensors->size, 2);
1215 EXPECT_EQ(params->output_tensors->data[0], 3);
1216 EXPECT_EQ(params->output_tensors->data[1], 4);
1217 }
1218
1219 TEST_F(TestDelegate, StaticDelegateMakesGraphImmutable) {
1220 delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
1221 ASSERT_EQ(
1222 interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
1223 kTfLiteOk);
1224 ASSERT_EQ(interpreter_->execution_plan().size(), 1);
1225
1226 // As the delegate doesn't support dynamic resizing, further graph mutation is
1227 // prohibited.
1228 ASSERT_NE(interpreter_->ResizeInputTensor(0, {0}), kTfLiteOk);
1229 ASSERT_NE(
1230 interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
1231 kTfLiteOk);
1232 }
1233
1234 TEST_F(TestDelegate, ComplexDelegate) {
1235 delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({1, 2}));
1236 interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
1237
1238 ASSERT_EQ(interpreter_->execution_plan().size(), 2);
1239 // 0th should be a non-delegated original op
1240 ASSERT_EQ(interpreter_->execution_plan()[0], 0);
1241 // 1st should be a new macro op (3) which didn't exist before.
1242 ASSERT_EQ(interpreter_->execution_plan()[1], 3);
1243 const auto* node_and_reg = interpreter_->node_and_registration(3);
1244 ASSERT_EQ(node_and_reg->second.custom_name,
1245 SimpleDelegate::FakeFusedRegistration().custom_name);
1246 }
1247
1248 TEST_F(TestDelegate, SetBufferHandleToInput) {
1249 delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
1250 TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
1251 interpreter_->ModifyGraphWithDelegate(delegate);
1252
1253 constexpr int kOutputTensorIndex = 0;
1254 TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex);
1255 ASSERT_EQ(tensor->delegate, nullptr);
1256 ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);
1257
1258 TfLiteBufferHandle handle = AllocateBufferHandle();
1259 TfLiteStatus status =
1260 interpreter_->SetBufferHandle(kOutputTensorIndex, handle, delegate);
1261 ASSERT_EQ(status, kTfLiteOk);
1262 EXPECT_EQ(tensor->delegate, delegate);
1263 EXPECT_EQ(tensor->buffer_handle, handle);
1264 }
1265
1266 TEST_F(TestDelegate, SetBufferHandleToOutput) {
1267 delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
1268 TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
1269 interpreter_->ModifyGraphWithDelegate(delegate);
1270
1271 constexpr int kOutputTensorIndex = 3;
1272 TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex);
1273 // Before setting the buffer handle, the tensor's `delegate` is already set
1274 // because it will be written by the delegate.
1275 ASSERT_EQ(tensor->delegate, delegate);
1276 ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);
1277
1278 TfLiteBufferHandle handle = AllocateBufferHandle();
1279 TfLiteStatus status =
1280 interpreter_->SetBufferHandle(kOutputTensorIndex, handle, delegate);
1281 ASSERT_EQ(status, kTfLiteOk);
1282 EXPECT_EQ(tensor->delegate, delegate);
1283 EXPECT_EQ(tensor->buffer_handle, handle);
1284 }
1285
1286 TEST_F(TestDelegate, SetInvalidHandleToTensor) {
1287 interpreter_->Invoke();
1288 delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
1289 TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
1290 interpreter_->ModifyGraphWithDelegate(delegate);
1291
1292 SimpleDelegate another_simple_delegate({0, 1, 2});
1293
1294 constexpr int kOutputTensorIndex = 3;
1295 TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex);
1296 // Before setting the buffer handle, the tensor's `delegate` is already set
1297 // because it will be written by the delegate.
1298 ASSERT_EQ(tensor->delegate, delegate);
1299 ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);
1300
1301 TfLiteBufferHandle handle = AllocateBufferHandle();
1302 TfLiteStatus status = interpreter_->SetBufferHandle(
1303 kOutputTensorIndex, handle,
1304 another_simple_delegate.get_tf_lite_delegate());
1305 // Setting a buffer handle to a tensor with another delegate will fail.
1306 ASSERT_EQ(status, kTfLiteError);
1307 EXPECT_EQ(tensor->delegate, delegate);
1308 EXPECT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);
1309 }
1310
1311 TEST_F(TestDelegate, ResizeInputWithNonDynamicDelegateShouldFail) {
1312 delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
1313 ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 2}), kTfLiteOk);
1314 ASSERT_EQ(interpreter_->ResizeInputTensor(1, {1, 2}), kTfLiteOk);
1315 ASSERT_EQ(
1316 interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
1317 kTfLiteOk);
1318 ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 2}), kTfLiteError);
1319 }
1320
1321 class TestDelegateWithDynamicTensors : public ::testing::Test {
1322 protected:
1323 void SetUp() override {
1324 interpreter_.reset(new Interpreter);
1325
1326 interpreter_->AddTensors(2);
1327 interpreter_->SetInputs({0});
1328 interpreter_->SetOutputs({1});
1329 TfLiteQuantizationParams quant;
1330 interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
1331 quant);
1332 interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
1333 quant);
1334 TfLiteRegistration reg = DynamicCopyOpRegistration();
1335 interpreter_->AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg);
1336
1337 delegate_.Prepare = [](TfLiteContext* context,
1338 TfLiteDelegate* delegate) -> TfLiteStatus {
1339 // In this test, the delegate replaces all the nodes if this function is
1340 // called.
1341 TfLiteIntArray* execution_plan;
1342 TF_LITE_ENSURE_STATUS(
1343 context->GetExecutionPlan(context, &execution_plan));
1344 context->ReplaceNodeSubsetsWithDelegateKernels(
1345 context, DelegateRegistration(), execution_plan, delegate);
1346 return kTfLiteOk;
1347 };
1348 delegate_.flags = kTfLiteDelegateFlagsNone;
1349 }
1350
1351 static TfLiteRegistration DynamicCopyOpRegistration() {
1352 TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
1353
1354 reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
1355 TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
1356 SetTensorToDynamic(output);
1357 return kTfLiteOk;
1358 };
1359
1360 reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
1361 // Not implemented since this isn't required in testing.
1362 return kTfLiteOk;
1363 };
1364 return reg;
1365 }
1366
1367 static TfLiteRegistration DelegateRegistration() {
1368 TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
1369 return reg;
1370 }
1371
1372 std::unique_ptr<Interpreter> interpreter_;
1373 TfLiteDelegate delegate_;
1374 };
1375
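// The fixture above builds a single-op graph (tensor 0 -> dynamic-copy op ->
// tensor 1) whose Prepare marks its output as dynamic, plus a delegate whose
// Prepare, if invoked, claims every node in the execution plan. The tests
// below toggle `delegate_.flags` to check when that Prepare actually runs.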
TEST_F(TestDelegateWithDynamicTensors, DisallowDynamicTensors) {
  interpreter_->ModifyGraphWithDelegate(&delegate_);

  ASSERT_EQ(interpreter_->execution_plan().size(), 1);
  // The interpreter should not call the delegate's `Prepare` when dynamic
  // tensors exist, so the node ID is unchanged.
  ASSERT_EQ(interpreter_->execution_plan()[0], 0);
}

TEST_F(TestDelegateWithDynamicTensors, AllowDynamicTensors) {
  delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors;
  interpreter_->ModifyGraphWithDelegate(&delegate_);

  ASSERT_EQ(interpreter_->execution_plan().size(), 1);
  // The node should be replaced because dynamic tensors are allowed, so the
  // only node ID in the execution plan changes from 0 to 1.
  ASSERT_EQ(interpreter_->execution_plan()[0], 1);
}

TEST_F(TestDelegateWithDynamicTensors, ModifyGraphAfterAllocate) {
  // Trigger allocation *before* delegate application.
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors;
  ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(&delegate_), kTfLiteOk);
  ASSERT_EQ(interpreter_->execution_plan().size(), 1);
  ASSERT_EQ(interpreter_->execution_plan()[0], 1);

  // Allocation should still succeed.
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
}

TEST(TestDelegateOwnership, ProperlyDisposed) {
  struct TfLiteInterpreterOwnedDelegate : public TfLiteDelegate {
    TfLiteInterpreterOwnedDelegate(bool* destroyed, bool* prepared)
        : destroyed(destroyed), prepared(prepared) {
      flags = kTfLiteDelegateFlagsNone;
      Prepare = [](TfLiteContext*, TfLiteDelegate* delegate) -> TfLiteStatus {
        *static_cast<TfLiteInterpreterOwnedDelegate*>(delegate)->prepared =
            true;
        return kTfLiteOk;
      };
    }
    ~TfLiteInterpreterOwnedDelegate() { *destroyed = true; }

    bool* destroyed;
    bool* prepared;
  };

  // Construct a delegate with observer booleans that record preparation and
  // destruction.
  bool destroyed = false;
  bool prepared = false;
  std::unique_ptr<TfLiteInterpreterOwnedDelegate> delegate(
      new TfLiteInterpreterOwnedDelegate(&destroyed, &prepared));
  {
    // Create an interpreter and assemble a simple graph.
    Interpreter interpreter;
    TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
    ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
    ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
    ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk);
    ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
                                                &registration),
              kTfLiteOk);

    // Pass delegate ownership to that interpreter.
    ASSERT_EQ(InterpreterTest::ModifyGraphWithDelegate(&interpreter,
                                                       std::move(delegate)),
              kTfLiteOk);

    // The delegate should be prepared as normal, and should be preserved.
    EXPECT_TRUE(prepared);
    EXPECT_FALSE(destroyed);

    // Interpreter interaction should not impact the delegate's validity.
    interpreter.AllocateTensors();
    interpreter.Invoke();
    EXPECT_FALSE(destroyed);
  }

  // Only after the interpreter is destroyed should the delegate be destroyed.
  EXPECT_TRUE(destroyed);
}

// CancellationData contains the data required to cancel a call to Invoke().
struct CancellationData {
  bool is_cancelled = false;
};

// Indicates whether Invoke() has been cancelled based on the value of the
// CancellationData object passed in.
bool CheckCancellation(void* data) {
  CancellationData* cancellation_data =
      static_cast<struct CancellationData*>(data);
  return cancellation_data->is_cancelled;
}
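
// A sketch of how the pieces above are wired together (this mirrors the
// fixture's SetUp() further below):
//   CancellationData data;
//   interpreter.SetCancellationFunction(&data, &CheckCancellation);
//   data.is_cancelled = true;  // e.g. set by an op or another thread;
//                              // Invoke() should then abort with an error.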

static struct CancellationData cancellation_data_;

// Test fixture to test cancellation within the Interpreter.
class CancellationTest : public ::testing::Test {
 public:
  TfLiteStatus Invoke() { return interpreter_.Invoke(); }
  void Cancel() { cancellation_data_.is_cancelled = true; }

  // Adds a CancelOp with input tensor `input` and output tensor `output`.
  void MakeCancelNode(int input, int output) {
    TfLiteRegistration op = CancelOpRegistration();
    ASSERT_EQ(interpreter_.AddNodeWithParameters({input}, {output}, nullptr, 0,
                                                 nullptr, &op),
              kTfLiteOk);
    ASSERT_EQ(interpreter_.ResizeInputTensor(input, {3}), kTfLiteOk);
  }

  // Adds an OkOp with input tensor `input` and output tensor `output`.
  void MakeOkNode(int input, int output) {
    TfLiteRegistration op = OkOpRegistration();
    ASSERT_EQ(interpreter_.AddNodeWithParameters({input}, {output}, nullptr, 0,
                                                 nullptr, &op),
              kTfLiteOk);
    ASSERT_EQ(interpreter_.ResizeInputTensor(input, {3}), kTfLiteOk);
  }

  Interpreter interpreter_;

 private:
  // Build the kernel registration for an op that cancels the operation.
  TfLiteRegistration CancelOpRegistration() {
    TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};

    // CancelOp::Prepare() sets the output size to the input size; it exists
    // only to give the op a well-formed Prepare step. The tensor contents are
    // not used.
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      TfLiteTensor* in_tensor = &context->tensors[node->inputs->data[0]];
      TfLiteTensor* out_tensor = &context->tensors[node->outputs->data[0]];
      TfLiteIntArray* new_size = TfLiteIntArrayCopy(in_tensor->dims);
      return context->ResizeTensor(context, out_tensor, new_size);
    };

    reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
      cancellation_data_.is_cancelled = true;
      return kTfLiteOk;
    };
    return reg;
  }

  // Build the kernel registration for an op that returns kTfLiteOk.
  TfLiteRegistration OkOpRegistration() {
    TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};

    // OkOp::Prepare() sets the output size to the input size; it exists only
    // to give the op a well-formed Prepare step. The tensor contents are not
    // used.
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      TfLiteTensor* in_tensor = &context->tensors[node->inputs->data[0]];
      TfLiteTensor* out_tensor = &context->tensors[node->outputs->data[0]];
      TfLiteIntArray* new_size = TfLiteIntArrayCopy(in_tensor->dims);
      return context->ResizeTensor(context, out_tensor, new_size);
    };

    reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
      return kTfLiteOk;
    };
    return reg;
  }

  void SetUp() final {
    cancellation_data_.is_cancelled = false;

    // Set up the interpreter. Create the input and output tensors.
    int num_tensors = 3;
    ASSERT_EQ(interpreter_.AddTensors(num_tensors), kTfLiteOk);
    interpreter_.SetInputs({0});
    interpreter_.SetOutputs({2});
    TfLiteQuantizationParams quantized;
    for (int tensor_index = 0; tensor_index < num_tensors; tensor_index++) {
      ASSERT_EQ(interpreter_.SetTensorParametersReadWrite(
                    tensor_index, kTfLiteFloat32, "", {3}, quantized),
                kTfLiteOk);
    }
    interpreter_.SetCancellationFunction(&cancellation_data_,
                                         &CheckCancellation);
  }
};

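// Each test below builds its graph from the three tensors created in SetUp()
// (tensor 0 is the graph input, tensor 2 the output), chaining CancelOp/OkOp
// nodes between them with MakeCancelNode()/MakeOkNode().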
TEST_F(CancellationTest, CancelBeforeInvoke) {
  // Cancel prior to calling Invoke.
  CancellationTest::MakeOkNode(1, 2);
  ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);

  CancellationTest::Cancel();
  TfLiteStatus invoke_error_code = CancellationTest::Invoke();
  ASSERT_EQ(invoke_error_code, kTfLiteError);
}

TEST_F(CancellationTest, CancelDuringInvoke) {
  // Builds a model whose first op sets the cancellation flag, verifying that
  // cancellation takes effect between ops.
  //
  // The first op sets the cancellation bit to true. The second op returns
  // `kTfLiteOk` if executed.
  CancellationTest::MakeCancelNode(0, 1);
  CancellationTest::MakeOkNode(1, 2);
  ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);

  TfLiteStatus invoke_error_code = CancellationTest::Invoke();
  ASSERT_EQ(invoke_error_code, kTfLiteError);
}

}  // namespace
}  // namespace tflite

int main(int argc, char** argv) {
  ::tflite::LogToStderr();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}