/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/model_builder.h"

#include <stddef.h>
#include <stdint.h>

#include <cstdlib>
#include <utility>
#include <vector>

#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/interpreter.h"

namespace tflite {
namespace gpu {
namespace {

TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank0) {
  TfLiteTensor tflite_tensor;
  tflite_tensor.name = "tensor_name";
  tflite_tensor.type = TfLiteType::kTfLiteFloat32;
  tflite_tensor.dims = TfLiteIntArrayCreate(1);
  tflite_tensor.dims->data[0] = 4;
  TensorRef<BHWC> tensor_ref;
  const auto status =
      ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
  TfLiteIntArrayFree(tflite_tensor.dims);
  ASSERT_TRUE(status.ok());
  EXPECT_EQ(tensor_ref.type, DataType::FLOAT32);
  EXPECT_EQ(tensor_ref.shape, BHWC(4, 1, 1, 1));
}

TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank1) {
  TfLiteTensor tflite_tensor;
  tflite_tensor.name = "tensor_name";
  tflite_tensor.type = TfLiteType::kTfLiteInt32;
  tflite_tensor.dims = TfLiteIntArrayCreate(2);
  tflite_tensor.dims->data[0] = 4;
  tflite_tensor.dims->data[1] = 5;
  TensorRef<BHWC> tensor_ref;
  const auto status =
      ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
  TfLiteIntArrayFree(tflite_tensor.dims);
  ASSERT_TRUE(status.ok());
  EXPECT_EQ(tensor_ref.type, DataType::INT32);
  EXPECT_EQ(tensor_ref.shape, BHWC(4, 1, 1, 5));
}

TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank2) {
  TfLiteTensor tflite_tensor;
  tflite_tensor.name = "tensor_name";
  tflite_tensor.type = TfLiteType::kTfLiteInt64;
  tflite_tensor.dims = TfLiteIntArrayCreate(3);
  tflite_tensor.dims->data[0] = 4;
  tflite_tensor.dims->data[1] = 5;
  tflite_tensor.dims->data[2] = 6;
  TensorRef<BHWC> tensor_ref;
  const auto status =
      ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
  TfLiteIntArrayFree(tflite_tensor.dims);
  ASSERT_TRUE(status.ok());
  EXPECT_EQ(tensor_ref.type, DataType::INT64);
  EXPECT_EQ(tensor_ref.shape, BHWC(4, 1, 5, 6));
}

TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank3) {
  TfLiteTensor tflite_tensor;
  tflite_tensor.name = "tensor_name";
  tflite_tensor.type = TfLiteType::kTfLiteUInt8;
  tflite_tensor.dims = TfLiteIntArrayCreate(4);
  tflite_tensor.dims->data[0] = 4;
  tflite_tensor.dims->data[1] = 5;
  tflite_tensor.dims->data[2] = 6;
  tflite_tensor.dims->data[3] = 7;
  TensorRef<BHWC> tensor_ref;
  const auto status =
      ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
  TfLiteIntArrayFree(tflite_tensor.dims);
  ASSERT_TRUE(status.ok());
  EXPECT_EQ(tensor_ref.type, DataType::UINT8);
  EXPECT_EQ(tensor_ref.shape, BHWC(4, 5, 6, 7));
}

TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankLT0) {
  TfLiteTensor tflite_tensor;
  tflite_tensor.name = "tensor_name";
  tflite_tensor.type = TfLiteType::kTfLiteFloat32;
  tflite_tensor.dims = TfLiteIntArrayCreate(0);
  TensorRef<BHWC> tensor_ref;
  const auto status =
      ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
  TfLiteIntArrayFree(tflite_tensor.dims);
  // TODO(b/130054481): Cover scalar.
  EXPECT_FALSE(status.ok());
}

TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankGT3) {
  TfLiteTensor tflite_tensor;
  tflite_tensor.name = "tensor_name";
  tflite_tensor.type = TfLiteType::kTfLiteFloat32;
  tflite_tensor.dims = TfLiteIntArrayCreate(5);
  TensorRef<BHWC> tensor_ref;
  const auto status =
      ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
  TfLiteIntArrayFree(tflite_tensor.dims);
  EXPECT_FALSE(status.ok());
}
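
// Helper base class for the partitioning tests below: owns an Interpreter, a
// manually-set execution plan, and the TfLiteDelegateParams handed back from
// the mocked TfLiteContext::PreviewDelegatePartitioning.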
class DelegatedInterpreter {
 public:
  explicit DelegatedInterpreter(int num_nodes) {
    exec_plan_ = TfLiteIntArrayCreate(num_nodes);
  }
  virtual ~DelegatedInterpreter() {
    TfLiteIntArrayFree(exec_plan_);
    for (auto params : delegate_params_) {
      TfLiteIntArrayFree(params.nodes_to_replace);
      TfLiteIntArrayFree(params.input_tensors);
      TfLiteIntArrayFree(params.output_tensors);
    }
  }

  // Get the TfLiteContext to be mocked for swapping out functions that have to
  // be called inside a delegate (i.e. in delegate kernel mode).
  TfLiteContext* context() { return interpreter_.primary_subgraph().context(); }

  // node(int) and registration(int) are used to implement
  // GetNodeAndRegistration. We can't implement those using
  //   TfLiteContext* context = interpreter_.primary_subgraph().context();
  //   context->GetNodeAndRegistration(context, &node, &registration);
  // here, because calling GetNodeAndRegistration from within its own
  // implementation would lead to an infinite loop.
  // Instead, we just call node_and_registration and use a const_cast.
  // These const_casts are a bit ugly, but I think less ugly than exposing
  // the private GetNodeAndRegistration method in Subgraph as public,
  // or making this class a friend of Subgraph.
  TfLiteNode* node(int index) {
    const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration =
        interpreter_.primary_subgraph().node_and_registration(index);
    return const_cast<TfLiteNode*>(&node_and_registration->first);
  }
  TfLiteRegistration* registration(int index) {
    const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration =
        interpreter_.primary_subgraph().node_and_registration(index);
    return const_cast<TfLiteRegistration*>(&node_and_registration->second);
  }

  TfLiteIntArray* exec_plan() const { return exec_plan_; }
  TfLiteDelegateParams* add_delegate_params() {
    delegate_params_.push_back(TfLiteDelegateParams());
    return &delegate_params_.back();
  }
  TfLiteDelegateParams* delegate_params() { return &delegate_params_.front(); }
  int num_delegate_params() { return delegate_params_.size(); }

 protected:
  Interpreter interpreter_;

 private:
  // The manually-set execution plan for this delegated interpreter.
  TfLiteIntArray* exec_plan_;

  // The TfLiteDelegateParams objects that are manually populated inside the
  // mocked TfLiteContext::PreviewDelegatePartitioning.
  std::vector<TfLiteDelegateParams> delegate_params_;
};
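
// Builds a graph of the following form, where OP is the builtin operator
// passed to the constructor:
//
//   t0 (FP16) -> DEQUANTIZE -> t1 (FP32) -> OP -> t4
//   t2 (FP16) -> DEQUANTIZE -> t3 (FP32) --/
//
// If const_dequantize_inputs is true, t0 and t2 are marked kTfLiteMmapRo,
// simulating DEQUANTIZE ops whose inputs are constants in the graph.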
class InterpreterFp16 : public DelegatedInterpreter {
 public:
  explicit InterpreterFp16(TfLiteBuiltinOperator op,
                           bool const_dequantize_inputs = true)
      : DelegatedInterpreter(3) {
    void* builtin_data = malloc(sizeof(int));
    EXPECT_EQ(interpreter_.AddTensors(5), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetInputs({0, 1}), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetOutputs({4}), kTfLiteOk);

    // Add a Dequantize Node.
    const TfLiteRegistration reg_dequant0 = {
        nullptr, nullptr, nullptr, nullptr, nullptr, kTfLiteBuiltinDequantize};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{0}, /*outputs=*/{1}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_dequant0),
              kTfLiteOk);

    // Add a Dequantize Node.
    const TfLiteRegistration reg_dequant1 = {
        nullptr, nullptr, nullptr, nullptr, nullptr, kTfLiteBuiltinDequantize};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{2}, /*outputs=*/{3}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_dequant1),
              kTfLiteOk);

    // Add a node that GPU delegate can parse.
    const TfLiteRegistration reg_op0 = {
        [](TfLiteContext* context, const char* buffer, size_t length) {
          return reinterpret_cast<void*>(new int(1));
        },
        [](TfLiteContext* context, void* buffer) {
          delete reinterpret_cast<int*>(buffer);
        },
        nullptr,
        nullptr,
        nullptr,
        op};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{1, 3}, /*outputs=*/{4}, /*init_data=*/nullptr,
                  /*init_data_size=*/0,
                  /*builtin_data=*/builtin_data,
                  /*registration=*/&reg_op0),
              kTfLiteOk);

    // Set the inputs of the Dequantize nodes to the fp16 type, and their
    // outputs to the fp32 type.
    const std::vector<int> dims = {1};
    TfLiteQuantization quantization;
    quantization.type = kTfLiteNoQuantization;
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            0, TfLiteType::kTfLiteFloat16, "t0", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            2, TfLiteType::kTfLiteFloat16, "t2", dims, quantization, false),
        kTfLiteOk);
    if (const_dequantize_inputs) {
      // This simulates the dequantize inputs being constants in the graph.
      // If this is not true, FP16GraphPartitionHelper should not consider the
      // corresponding DEQUANTIZE ops.
      auto* tensor0 = interpreter_.tensor(0);
      auto* tensor2 = interpreter_.tensor(2);
      tensor0->allocation_type = kTfLiteMmapRo;
      tensor2->allocation_type = kTfLiteMmapRo;
    }
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            1, TfLiteType::kTfLiteFloat32, "t1", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            3, TfLiteType::kTfLiteFloat32, "t3", dims, quantization, false),
        kTfLiteOk);

    exec_plan()->data[0] = 0;
    exec_plan()->data[1] = 1;
    exec_plan()->data[2] = 2;
  }
};

// **NOTE**: we have several interpreter instances created at global scope to
// test *exactly* the GetOpsToReplace function alone, and not the sequence of
// function calls that includes GetOpsToReplace when calling
// ModifyGraphWithDelegate. A TfLiteContext is needed to test GetOpsToReplace,
// but TfLiteContexts intentionally make it difficult to call certain functions
// in a non-delegate context (see tensorflow/lite/subgraph/subgraph.cc for
// details). We create our own GetExecutionPlan, GetNodeAndRegistration and
// PreviewDelegatePartitioning lambdas inside each test, but we can't use local
// captures without changing the function signature. Therefore, this test data
// lives at global scope in order to be accessible inside the lambdas.

InterpreterFp16* interpreter_fp16_add_op =
    new InterpreterFp16(kTfLiteBuiltinAdd);

TEST(ModelBuilderTest, GetOpsToReplaceAcceptsFp16DequantizeNodes) {
  // Before pruning, the graph has three nodes:
  //
  //   t0 (FP16) -> DequantNode -> t1 (FP32) -> Add -> t4
  //   t2 (FP16) -> DequantNode -> t3 (FP32) --/
  //
  // OpsToReplace should choose all three nodes for replacement, and
  // the graph on the GPU will look like this (no Dequants):
  //
  //   t0 (FP16) --> Add -> t4
  //   t2 (FP16) --/
  //
  TfLiteContext* context = interpreter_fp16_add_op->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter_fp16_add_op->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter_fp16_add_op->node(node_index);
    *registration = interpreter_fp16_add_op->registration(node_index);
    return kTfLiteOk;
  };
  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        // The partitioner should accept only the Add op initially.
        EXPECT_EQ(nodes_to_replace->size, 1);
        // Single partition output.
        auto params = interpreter_fp16_add_op->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 2;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 1;
        params->input_tensors->data[1] = 3;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 4;

        *partition_params_array = interpreter_fp16_add_op->delegate_params();
        *num_partitions = interpreter_fp16_add_op->num_delegate_params();
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);

  // The Dequant nodes are added to ops_to_replace as a post-processing step by
  // the FP16GraphPartitioner. ADD is delegated with its inputs pointing to the
  // FP16 inputs.
  EXPECT_EQ(ops_to_replace->size, 3);
  TfLiteNode* node = nullptr;
  TfLiteRegistration* registration = nullptr;
  context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
                                  &registration);
  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
            TfLiteType::kTfLiteFloat16);
  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
            TfLiteType::kTfLiteFloat16);
  TfLiteIntArrayFree(ops_to_replace);
}

InterpreterFp16* interpreter_fp16_non_constant =
    new InterpreterFp16(kTfLiteBuiltinAdd, /*const_dequantize_inputs=*/false);

// Same as GetOpsToReplaceAcceptsFp16DequantizeNodes, but the DEQUANTIZE inputs
// are not constant. As a result, we don't allow the delegate to accept them.
TEST(ModelBuilderTest, GetOpsToReplaceRejectsNonConstantFp16DequantizeNodes) {
  TfLiteContext* context = interpreter_fp16_non_constant->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter_fp16_non_constant->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter_fp16_non_constant->node(node_index);
    *registration = interpreter_fp16_non_constant->registration(node_index);
    return kTfLiteOk;
  };
  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        // The partitioner should accept only the Add op initially.
        EXPECT_EQ(nodes_to_replace->size, 1);
        // Single partition output.
        auto params = interpreter_fp16_non_constant->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 2;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 1;
        params->input_tensors->data[1] = 3;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 4;

        *partition_params_array =
            interpreter_fp16_non_constant->delegate_params();
        *num_partitions = interpreter_fp16_non_constant->num_delegate_params();
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);

  // Only ADD is delegated, with FP32 (dequantized) inputs.
  EXPECT_EQ(ops_to_replace->size, 1);
  TfLiteNode* node = nullptr;
  TfLiteRegistration* registration = nullptr;
  context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
                                  &registration);
  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
            TfLiteType::kTfLiteFloat32);
  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
            TfLiteType::kTfLiteFloat32);
  TfLiteIntArrayFree(ops_to_replace);
}

InterpreterFp16* interpreter_fp16_gt_op =
    new InterpreterFp16(kTfLiteBuiltinGreater);

TEST(ModelBuilderTest, GetOpsToReplaceRejectsFp16DequantizeNodes) {
  // Before pruning, the graph has three nodes:
  //
  //   t0 (FP16) -> DequantNode -> t1 (FP32) -> Greater Op -> t4
  //   t2 (FP16) -> DequantNode -> t3 (FP32) --/
  //
  // Because there is no GPU equivalent for the Greater op, we don't choose any
  // nodes.

  TfLiteContext* context = interpreter_fp16_gt_op->context();
  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter_fp16_gt_op->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter_fp16_gt_op->node(node_index);
    *registration = interpreter_fp16_gt_op->registration(node_index);
    return kTfLiteOk;
  };
  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        // No selected nodes.
        EXPECT_EQ(nodes_to_replace->size, 0);
        *partition_params_array = nullptr;
        *num_partitions = 0;
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);

  // No nodes were found to replace.
  EXPECT_EQ(ops_to_replace->size, 0);
  // Inputs to Greater op are still fp32.
  TfLiteNode* node = nullptr;
  TfLiteRegistration* registration = nullptr;
  const int kGreaterOpIndex = 2;
  context->GetNodeAndRegistration(context, kGreaterOpIndex, &node,
                                  &registration);
  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
            TfLiteType::kTfLiteFloat32);
  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
            TfLiteType::kTfLiteFloat32);
  TfLiteIntArrayFree(ops_to_replace);
}
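
// Builds the graph used by GetOpsToReplaceDoesNotPruneUint8:
//
//   t0 (UINT8) -> DEQUANTIZE -> t1 (FP32) -> ADD -> t3
//                               t2 (FP32) --/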
class InterpreterFp32 : public DelegatedInterpreter {
 public:
  InterpreterFp32() : DelegatedInterpreter(2) {
    void* builtin_data = malloc(sizeof(int));
    EXPECT_EQ(interpreter_.AddTensors(4), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetInputs({0, 2}), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetOutputs({3}), kTfLiteOk);

    // Add a Dequantize Node with uint8 input.
    const TfLiteRegistration reg_dequant0 = {/*init=*/nullptr,
                                             /*free=*/nullptr,
                                             /*prepare=*/nullptr,
                                             /*invoke=*/nullptr,
                                             /*profiling_string=*/nullptr,
                                             kTfLiteBuiltinDequantize};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{0}, /*outputs=*/{1}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_dequant0),
              kTfLiteOk);

    // Add a node that GPU delegate can parse.
    const TfLiteRegistration reg_add0 = {
        [](TfLiteContext* context, const char* buffer, size_t length) {
          return reinterpret_cast<void*>(new int(1));
        },
        [](TfLiteContext* context, void* buffer) {
          delete reinterpret_cast<int*>(buffer);
        },
        nullptr,
        nullptr,
        nullptr,
        kTfLiteBuiltinAdd};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{1, 2}, /*outputs=*/{3}, /*init_data=*/nullptr,
                  /*init_data_size=*/0,
                  /*builtin_data=*/builtin_data,
                  /*registration=*/&reg_add0),
              kTfLiteOk);

    const std::vector<int> dims = {1};
    TfLiteQuantization quantization;
    quantization.type = kTfLiteNoQuantization;
    EXPECT_EQ(interpreter_.SetTensorParametersReadWrite(
                  0, TfLiteType::kTfLiteUInt8, "t0", dims, quantization, false),
              kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            1, TfLiteType::kTfLiteFloat32, "t1", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            2, TfLiteType::kTfLiteFloat32, "t2", dims, quantization, false),
        kTfLiteOk);

    exec_plan()->data[0] = 0;
    exec_plan()->data[1] = 1;
  }
};

InterpreterFp32* interpreter_fp32 = new InterpreterFp32();

TEST(ModelBuilderTest, GetOpsToReplaceDoesNotPruneUint8) {
  // A graph with a Dequant node that has a uint8 input is not pruned, as this
  // op is currently not supported on the GPU. The Dequant op will therefore be
  // scheduled to run on the CPU, while the remaining supported Add op runs on
  // the GPU.
  //
  //   t0 (uint8) --> Dequant --> t1 (FP32) --> Add -> t3
  //                              t2 (FP32) --/
  //
  TfLiteContext* context = interpreter_fp32->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter_fp32->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter_fp32->node(node_index);
    *registration = interpreter_fp32->registration(node_index);
    return kTfLiteOk;
  };
  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        auto params = interpreter_fp32->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 1;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 1;
        params->input_tensors->data[1] = 2;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 3;

        *partition_params_array = interpreter_fp32->delegate_params();
        *num_partitions = interpreter_fp32->num_delegate_params();
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);

  // As the Dequant op is not pruned and the ADD op could run on GPU, we have
  // 1 partition.
  EXPECT_EQ(ops_to_replace->size, 1);
  // ADD at index 1.
  EXPECT_EQ(1, ops_to_replace->data[0]);

  TfLiteIntArrayFree(ops_to_replace);
}
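
// Builds the graph used by GetOpsToReplaceMultiplePartitions: the DEQUANTIZE
// node (uint8 input) and the PACK node are not delegable, so the two delegable
// ADD ops end up in separate partitions.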
class Interpreter2Fp32 : public DelegatedInterpreter {
 public:
  Interpreter2Fp32() : DelegatedInterpreter(4) {
    void* builtin_data = malloc(sizeof(int));
    EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetInputs({0, 2, 4, 6}), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetOutputs({7}), kTfLiteOk);

    // Add a Dequantize Node with uint8 input.
    const TfLiteRegistration reg_dequant = {/*init=*/nullptr,
                                            /*free=*/nullptr,
                                            /*prepare=*/nullptr,
                                            /*invoke=*/nullptr,
                                            /*profiling_string=*/nullptr,
                                            kTfLiteBuiltinDequantize};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{0}, /*outputs=*/{1}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_dequant),
              kTfLiteOk);

    // Add an ADD node that GPU delegate can parse.
    const TfLiteRegistration reg_add0 = {
        [](TfLiteContext* context, const char* buffer, size_t length) {
          return reinterpret_cast<void*>(new int(1));
        },
        [](TfLiteContext* context, void* buffer) {
          delete reinterpret_cast<int*>(buffer);
        },
        nullptr,
        nullptr,
        nullptr,
        kTfLiteBuiltinAdd};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{1, 2}, /*outputs=*/{3}, /*init_data=*/nullptr,
                  /*init_data_size=*/0,
                  /*builtin_data=*/builtin_data,
                  /*registration=*/&reg_add0),
              kTfLiteOk);

    // Add a Pack node that the GPU delegate doesn't support.
    const TfLiteRegistration reg_pack = {/*init=*/nullptr,
                                         /*free=*/nullptr,
                                         /*prepare=*/nullptr,
                                         /*invoke=*/nullptr,
                                         /*profiling_string=*/nullptr,
                                         kTfLiteBuiltinPack};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{3, 4}, /*outputs=*/{5}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_pack),
              kTfLiteOk);

    const TfLiteRegistration reg_add1 = {
        [](TfLiteContext* context, const char* buffer, size_t length) {
          return reinterpret_cast<void*>(new int[2]);
        },
        [](TfLiteContext* context, void* buffer) {
          delete[] reinterpret_cast<int*>(buffer);
        },
        nullptr,
        nullptr,
        nullptr,
        kTfLiteBuiltinAdd};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{5, 6}, /*outputs=*/{7}, /*init_data=*/nullptr,
                  /*init_data_size=*/0,
                  /*builtin_data=*/builtin_data,
                  /*registration=*/&reg_add1),
              kTfLiteOk);

    std::vector<int> dims = {1};
    TfLiteQuantization quantization;
    quantization.type = kTfLiteNoQuantization;
    EXPECT_EQ(interpreter_.SetTensorParametersReadWrite(
                  0, TfLiteType::kTfLiteUInt8, "t0", dims, quantization, false),
              kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            1, TfLiteType::kTfLiteFloat32, "t1", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            2, TfLiteType::kTfLiteFloat32, "t2", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            3, TfLiteType::kTfLiteFloat32, "t3", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            4, TfLiteType::kTfLiteFloat32, "t4", dims, quantization, false),
        kTfLiteOk);

    dims.push_back(2);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            5, TfLiteType::kTfLiteFloat32, "t5", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            6, TfLiteType::kTfLiteFloat32, "t6", dims, quantization, false),
        kTfLiteOk);

    exec_plan()->data[0] = 0;
    exec_plan()->data[1] = 1;
    exec_plan()->data[2] = 2;
    exec_plan()->data[3] = 3;
  }
};

Interpreter2Fp32* interpreter2_fp32 = new Interpreter2Fp32();

TEST(ModelBuilderTest, GetOpsToReplaceMultiplePartitions) {
  // A graph with a Dequant node (uint8 input) and a Pack node is not pruned,
  // as these ops are currently not supported on the GPU. They will be
  // scheduled to run on the CPU, while the remaining supported Add ops run on
  // the GPU.
  //
  //   t0 (uint8) -> Dequant(0) -> t1 (FP32) -> Add(1) -> t3 (FP32) -> PACK (2)
  //                               t2 (FP32) -/           t4 (FP32) -/
  //   PACK (2) -> t5 (FP32) -> Add(3) -> t7
  //               t6 (FP32) -/
  //
  TfLiteContext* context = interpreter2_fp32->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter2_fp32->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter2_fp32->node(node_index);
    *registration = interpreter2_fp32->registration(node_index);
    return kTfLiteOk;
  };
  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        auto params = interpreter2_fp32->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 1;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 1;
        params->input_tensors->data[1] = 2;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 3;

        params = interpreter2_fp32->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 3;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 5;
        params->input_tensors->data[1] = 6;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 7;

        *partition_params_array = interpreter2_fp32->delegate_params();
        *num_partitions = interpreter2_fp32->num_delegate_params();
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(
      context, /*allow_quant_ops=*/false, /*max_delegated_partitions=*/2);

  // As the Dequant op is not pruned and the two ADD ops could run on GPU, we
  // have 2 partitions.
  EXPECT_EQ(ops_to_replace->size, 2);
  // ADD at index 1.
  EXPECT_EQ(1, ops_to_replace->data[0]);
  // ADD at index 3.
  EXPECT_EQ(3, ops_to_replace->data[1]);

  TfLiteIntArrayFree(ops_to_replace);
}
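
// Builds a graph in which three DEQUANTIZE nodes (with constant FP16 inputs)
// feed two two-input ops. If both_ops_supported is true, both consumers are
// ADD ops; otherwise the first consumer is GREATER (not supported by the GPU
// delegate) and the second is ADD.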
class InterpreterMultiNode : public DelegatedInterpreter {
 public:
  explicit InterpreterMultiNode(bool both_ops_supported = true)
      : DelegatedInterpreter(5) {
    void* builtin_data = malloc(sizeof(int));
    EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetInputs({0, 1, 2}), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetOutputs({6, 7}), kTfLiteOk);

    // Add 3 Dequantize Nodes with float16 input.
    for (int i = 0; i < 3; ++i) {
      const TfLiteRegistration reg_dequant = {/*init=*/nullptr,
                                              /*free=*/nullptr,
                                              /*prepare=*/nullptr,
                                              /*invoke=*/nullptr,
                                              /*profiling_string=*/nullptr,
                                              kTfLiteBuiltinDequantize};
      EXPECT_EQ(interpreter_.AddNodeWithParameters(
                    /*inputs=*/{i}, /*outputs=*/{i + 3}, /*init_data=*/nullptr,
                    /*init_data_size=*/0, /*builtin_data=*/nullptr,
                    /*registration=*/&reg_dequant),
                kTfLiteOk);
    }

    if (both_ops_supported) {
      // Add 2 ADD ops.
      const TfLiteRegistration reg_add0 = {
          [](TfLiteContext* context, const char* buffer, size_t length) {
            return reinterpret_cast<void*>(new int(1));
          },
          [](TfLiteContext* context, void* buffer) {
            delete reinterpret_cast<int*>(buffer);
          },
          nullptr,
          nullptr,
          nullptr,
          kTfLiteBuiltinAdd};
      EXPECT_EQ(interpreter_.AddNodeWithParameters(
                    /*inputs=*/{4, 5}, /*outputs=*/{7}, /*init_data=*/nullptr,
                    /*init_data_size=*/0,
                    /*builtin_data=*/builtin_data,
                    /*registration=*/&reg_add0),
                kTfLiteOk);

      const TfLiteRegistration reg_add1 = {
          [](TfLiteContext* context, const char* buffer, size_t length) {
            return reinterpret_cast<void*>(new int(1));
          },
          [](TfLiteContext* context, void* buffer) {
            delete reinterpret_cast<int*>(buffer);
          },
          nullptr,
          nullptr,
          nullptr,
          kTfLiteBuiltinAdd};
      EXPECT_EQ(interpreter_.AddNodeWithParameters(
                    /*inputs=*/{3, 4}, /*outputs=*/{6}, /*init_data=*/nullptr,
                    /*init_data_size=*/0,
                    /*builtin_data=*/builtin_data,
                    /*registration=*/&reg_add1),
                kTfLiteOk);
    } else {
      // Add the GREATER op node that the GPU delegate doesn't support.
      const TfLiteRegistration reg_greater = {
          [](TfLiteContext* context, const char* buffer, size_t length) {
            return reinterpret_cast<void*>(new int(1));
          },
          [](TfLiteContext* context, void* buffer) {
            delete reinterpret_cast<int*>(buffer);
          },
          nullptr,
          nullptr,
          nullptr,
          kTfLiteBuiltinGreater};
      EXPECT_EQ(interpreter_.AddNodeWithParameters(
                    /*inputs=*/{3, 4}, /*outputs=*/{6}, /*init_data=*/nullptr,
                    /*init_data_size=*/0,
                    /*builtin_data=*/builtin_data,
                    /*registration=*/&reg_greater),
                kTfLiteOk);

      // Add the ADD op node that the GPU delegate supports.
      const TfLiteRegistration reg_add0 = {
          [](TfLiteContext* context, const char* buffer, size_t length) {
            return reinterpret_cast<void*>(new int(1));
          },
          [](TfLiteContext* context, void* buffer) {
            delete reinterpret_cast<int*>(buffer);
          },
          nullptr,
          nullptr,
          nullptr,
          kTfLiteBuiltinAdd};
      EXPECT_EQ(interpreter_.AddNodeWithParameters(
                    /*inputs=*/{4, 5}, /*outputs=*/{7}, /*init_data=*/nullptr,
                    /*init_data_size=*/0,
                    /*builtin_data=*/builtin_data,
                    /*registration=*/&reg_add0),
                kTfLiteOk);
    }
    const std::vector<int> dims = {1};
    TfLiteQuantization quantization;
    quantization.type = kTfLiteNoQuantization;
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            0, TfLiteType::kTfLiteFloat16, "t0", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            1, TfLiteType::kTfLiteFloat16, "t1", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            2, TfLiteType::kTfLiteFloat16, "t2", dims, quantization, false),
        kTfLiteOk);
    // Simulate DEQUANTIZE inputs being constants.
    auto* tensor0 = interpreter_.tensor(0);
    auto* tensor1 = interpreter_.tensor(1);
    auto* tensor2 = interpreter_.tensor(2);
    tensor0->allocation_type = kTfLiteMmapRo;
    tensor1->allocation_type = kTfLiteMmapRo;
    tensor2->allocation_type = kTfLiteMmapRo;
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            3, TfLiteType::kTfLiteFloat32, "t3", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            4, TfLiteType::kTfLiteFloat32, "t4", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            5, TfLiteType::kTfLiteFloat32, "t5", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            6, TfLiteType::kTfLiteFloat32, "t6", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            7, TfLiteType::kTfLiteFloat32, "t7", dims, quantization, false),
        kTfLiteOk);

    exec_plan()->data[0] = 0;
    exec_plan()->data[1] = 1;
    exec_plan()->data[2] = 2;
    exec_plan()->data[3] = 3;
    exec_plan()->data[4] = 4;
  }
};

InterpreterMultiNode* interpreter_mn =
    new InterpreterMultiNode(/*both_ops_supported=*/false);

TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectFp16Nodes_SinglePartition) {
  // A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
  // 'Add' can be replaced by the GPU delegate, but 'Greater' cannot.
  //
  //   t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(3) -> t6
  //   t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
  //                                           --\
  //   t2 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7
  //
  // OpsToReplace should accept 'Add' & the Dequant nodes that only output to
  // it (in this case, Dequant(2)).
  TfLiteContext* context = interpreter_mn->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter_mn->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter_mn->node(node_index);
    *registration = interpreter_mn->registration(node_index);
    return kTfLiteOk;
  };
  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        // The FP16GraphPartitioner should only mark the ADD op as accepted.
        EXPECT_EQ(nodes_to_replace->size, 1);
        EXPECT_EQ(nodes_to_replace->data[0], 4);
        // Single partition.
        auto params = interpreter_mn->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 4;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 1;
        params->input_tensors->data[1] = 3;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 7;

        *partition_params_array = interpreter_mn->delegate_params();
        *num_partitions = interpreter_mn->num_delegate_params();
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);

  // Post-PreviewDelegatePartitioning, the partitioner will add Dequant(2) to
  // ops_to_replace, since it only outputs to a delegated node.
  EXPECT_EQ(ops_to_replace->size, 2);
  // Op at index 4 is the Add op.
  EXPECT_EQ(ops_to_replace->data[0], 4);
  EXPECT_EQ(ops_to_replace->data[1], 2);
  // Verify that Add op has fp16 inputs.
  TfLiteNode* node = nullptr;
  TfLiteRegistration* registration = nullptr;
  context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
                                  &registration);
  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
            TfLiteType::kTfLiteFloat16);
  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
            TfLiteType::kTfLiteFloat16);
  TfLiteIntArrayFree(ops_to_replace);
}

InterpreterMultiNode* interpreter_mn2 =
    new InterpreterMultiNode(/*both_ops_supported=*/true);
TEST(ModelBuilderTest,
     GetOpsToReplaceSelectsCorrectFp16Nodes_MultiplePartitions) {
  // A graph with three Dequant nodes feeding two Add ops.
  //
  //   t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Add(3) -> t6
  //   t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
  //                                           --\
  //   t2 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7
  //
  // In this test case, we purposely partition Add(3) & Add(4) into different
  // partitions, to check if Dequant nodes that output *only* to the first
  // partition's nodes are accepted.

  TfLiteContext* context = interpreter_mn2->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter_mn2->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter_mn2->node(node_index);
    *registration = interpreter_mn2->registration(node_index);
    return kTfLiteOk;
  };

  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        // The FP16GraphPartitioner should mark only the two ADD ops as
        // accepted.
        EXPECT_EQ(nodes_to_replace->size, 2);
        EXPECT_EQ(nodes_to_replace->data[0], 3);
        EXPECT_EQ(nodes_to_replace->data[1], 4);
        // Technically, both ADD ops should end up in the same partition.
        // But we put them in different partitions to test post-processing with
        // DEQUANTIZE nodes.
        // First partition with Add(3).
        auto params = interpreter_mn2->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 3;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 3;
        params->input_tensors->data[1] = 4;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 6;
        // Second partition with Add(4).
        params = interpreter_mn2->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 4;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 4;
        params->input_tensors->data[1] = 5;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 7;

        *partition_params_array = interpreter_mn2->delegate_params();
        *num_partitions = interpreter_mn2->num_delegate_params();
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(
      context, /*allow_quant_ops=*/false, /*max_delegated_partitions=*/2);

  // Three ops should be selected:
  // Add(3), Dequant(x), Add(4)
  // Since both partitions are of size 1, either could end up as the 'first'
  // partition with one Dequant node added for it.
  EXPECT_EQ(ops_to_replace->size, 3);

  TfLiteNode* node = nullptr;
  TfLiteRegistration* registration = nullptr;
  // Verify that both Add ops have fp16 inputs.
  context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
                                  &registration);
  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
            TfLiteType::kTfLiteFloat16);
  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
            TfLiteType::kTfLiteFloat16);
  context->GetNodeAndRegistration(context, ops_to_replace->data[2], &node,
                                  &registration);
  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
            TfLiteType::kTfLiteFloat16);
  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
            TfLiteType::kTfLiteFloat16);
  // Verify that the op at index 1 is a Dequant outputting to a single Add.
  EXPECT_TRUE(ops_to_replace->data[1] == 0 || ops_to_replace->data[1] == 2);
  TfLiteIntArrayFree(ops_to_replace);
}

// Adds the pattern:
//
//   float -> QUANTIZE -> ADD -> DEQUANTIZE -> float
//   float -> QUANTIZE ----^
//
// The tensors between the QUANTIZE & DEQUANTIZE nodes are int8.
class InterpreterQuantized : public DelegatedInterpreter {
 public:
  InterpreterQuantized() : DelegatedInterpreter(4) {
    void* builtin_data = malloc(sizeof(int));
    EXPECT_EQ(interpreter_.AddTensors(6), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetInputs({0, 3}), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetOutputs({5}), kTfLiteOk);

    // QUANTIZE 1
    const TfLiteRegistration reg_quant0 = {/*init=*/nullptr,
                                           /*free=*/nullptr,
                                           /*prepare=*/nullptr,
                                           /*invoke=*/nullptr,
                                           /*profiling_string=*/nullptr,
                                           kTfLiteBuiltinQuantize};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{0}, /*outputs=*/{1}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_quant0),
              kTfLiteOk);

    // QUANTIZE 2
    const TfLiteRegistration reg_quant1 = {/*init=*/nullptr,
                                           /*free=*/nullptr,
                                           /*prepare=*/nullptr,
                                           /*invoke=*/nullptr,
                                           /*profiling_string=*/nullptr,
                                           kTfLiteBuiltinQuantize};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{3}, /*outputs=*/{2}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_quant1),
              kTfLiteOk);

    // ADD
    const TfLiteRegistration reg_add0 = {
        [](TfLiteContext* context, const char* buffer, size_t length) {
          return reinterpret_cast<void*>(new int(1));
        },
        [](TfLiteContext* context, void* buffer) {
          delete reinterpret_cast<int*>(buffer);
        },
        nullptr,
        nullptr,
        nullptr,
        kTfLiteBuiltinAdd};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{1, 2}, /*outputs=*/{4}, /*init_data=*/nullptr,
                  /*init_data_size=*/0,
                  /*builtin_data=*/builtin_data,
                  /*registration=*/&reg_add0),
              kTfLiteOk);

    // DEQUANTIZE
    const TfLiteRegistration reg_dequant0 = {/*init=*/nullptr,
                                             /*free=*/nullptr,
                                             /*prepare=*/nullptr,
                                             /*invoke=*/nullptr,
                                             /*profiling_string=*/nullptr,
                                             kTfLiteBuiltinDequantize};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{4}, /*outputs=*/{5}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_dequant0),
              kTfLiteOk);

    const std::vector<int> dims = {1, 3, 3, 2};

    // Input & output tensors are floating-point.
    TfLiteQuantization no_quantization;
    no_quantization.type = kTfLiteNoQuantization;
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            0, TfLiteType::kTfLiteFloat32, "t0", dims, no_quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            3, TfLiteType::kTfLiteFloat32, "t3", dims, no_quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            5, TfLiteType::kTfLiteFloat32, "t5", dims, no_quantization, false),
        kTfLiteOk);
    // Other tensors are int8.
    float scale = 0.5f;
    int32_t zero_point = 12;
    TfLiteQuantization rw_quantization;
    rw_quantization.type = kTfLiteAffineQuantization;
    auto* rw_affine_quantization = static_cast<TfLiteAffineQuantization*>(
        malloc(sizeof(TfLiteAffineQuantization)));
    rw_affine_quantization->scale = TfLiteFloatArrayCreate(1);
    rw_affine_quantization->zero_point = TfLiteIntArrayCreate(1);
    rw_affine_quantization->scale->data[0] = scale;
    rw_affine_quantization->zero_point->data[0] = zero_point;
    rw_quantization.params = rw_affine_quantization;
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            1, TfLiteType::kTfLiteInt8, "t1", dims, rw_quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            2, TfLiteType::kTfLiteInt8, "t2", dims, rw_quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            4, TfLiteType::kTfLiteInt8, "t4", dims, rw_quantization, false),
        kTfLiteOk);

    exec_plan()->data[0] = 0;
    exec_plan()->data[1] = 1;
    exec_plan()->data[2] = 2;
    exec_plan()->data[3] = 3;
  }
};

InterpreterQuantized* interpreter_quant = new InterpreterQuantized();
TEST(ModelBuilderTest, GetOpsToReplace_AllowQuantOps) {
  TfLiteContext* context = interpreter_quant->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter_quant->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter_quant->node(node_index);
    *registration = interpreter_quant->registration(node_index);
    return kTfLiteOk;
  };
  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        if (nodes_to_replace->size == 0) {
          *num_partitions = 0;
          return kTfLiteOk;
        }
        auto params = interpreter_quant->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(3);
        params->nodes_to_replace->data[0] = 0;
        params->nodes_to_replace->data[1] = 1;
        params->nodes_to_replace->data[2] = 2;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 0;
        params->input_tensors->data[1] = 3;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 4;

        *partition_params_array = interpreter_quant->delegate_params();
        *num_partitions = interpreter_quant->num_delegate_params();
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace =
      GetOpsToReplace(context, /*allow_quant_ops=*/true);
  // If we allow quant ops, two QUANTIZE & one ADD node should be accepted.
  EXPECT_EQ(ops_to_replace->size, 3);
  EXPECT_EQ(0, ops_to_replace->data[0]);
  EXPECT_EQ(1, ops_to_replace->data[1]);
  EXPECT_EQ(2, ops_to_replace->data[2]);

  TfLiteIntArray* ops_to_replace_without_quant =
      GetOpsToReplace(context, /*allow_quant_ops=*/false);
  // No ops should be accepted.
  EXPECT_EQ(ops_to_replace_without_quant->size, 0);

  TfLiteIntArrayFree(ops_to_replace);
  TfLiteIntArrayFree(ops_to_replace_without_quant);
}

}  // namespace
}  // namespace gpu
}  // namespace tflite