/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/model_builder.h"

#include <stddef.h>
#include <stdint.h>

#include <cstdlib>
#include <utility>
#include <vector>

#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/interpreter.h"

namespace tflite {
namespace gpu {
namespace {

TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank0) {
  TfLiteTensor tflite_tensor;
  tflite_tensor.name = "tensor_name";
  tflite_tensor.type = TfLiteType::kTfLiteFloat32;
  tflite_tensor.dims = TfLiteIntArrayCreate(1);
  tflite_tensor.dims->data[0] = 4;
  TensorRef<BHWC> tensor_ref;
  const auto status =
      ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
  TfLiteIntArrayFree(tflite_tensor.dims);
  ASSERT_TRUE(status.ok());
  EXPECT_EQ(tensor_ref.type, DataType::FLOAT32);
  EXPECT_EQ(tensor_ref.shape, BHWC(4, 1, 1, 1));
}

TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank1) {
  TfLiteTensor tflite_tensor;
  tflite_tensor.name = "tensor_name";
  tflite_tensor.type = TfLiteType::kTfLiteInt32;
  tflite_tensor.dims = TfLiteIntArrayCreate(2);
  tflite_tensor.dims->data[0] = 4;
  tflite_tensor.dims->data[1] = 5;
  TensorRef<BHWC> tensor_ref;
  const auto status =
      ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
  TfLiteIntArrayFree(tflite_tensor.dims);
  ASSERT_TRUE(status.ok());
  EXPECT_EQ(tensor_ref.type, DataType::INT32);
  EXPECT_EQ(tensor_ref.shape, BHWC(4, 1, 1, 5));
}

TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank2) {
  TfLiteTensor tflite_tensor;
  tflite_tensor.name = "tensor_name";
  tflite_tensor.type = TfLiteType::kTfLiteInt64;
  tflite_tensor.dims = TfLiteIntArrayCreate(3);
  tflite_tensor.dims->data[0] = 4;
  tflite_tensor.dims->data[1] = 5;
  tflite_tensor.dims->data[2] = 6;
  TensorRef<BHWC> tensor_ref;
  const auto status =
      ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
  TfLiteIntArrayFree(tflite_tensor.dims);
  ASSERT_TRUE(status.ok());
  EXPECT_EQ(tensor_ref.type, DataType::INT64);
  EXPECT_EQ(tensor_ref.shape, BHWC(4, 1, 5, 6));
}

TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank3) {
  TfLiteTensor tflite_tensor;
  tflite_tensor.name = "tensor_name";
  tflite_tensor.type = TfLiteType::kTfLiteUInt8;
  tflite_tensor.dims = TfLiteIntArrayCreate(4);
  tflite_tensor.dims->data[0] = 4;
  tflite_tensor.dims->data[1] = 5;
  tflite_tensor.dims->data[2] = 6;
  tflite_tensor.dims->data[3] = 7;
  TensorRef<BHWC> tensor_ref;
  const auto status =
      ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
  TfLiteIntArrayFree(tflite_tensor.dims);
  ASSERT_TRUE(status.ok());
  EXPECT_EQ(tensor_ref.type, DataType::UINT8);
  EXPECT_EQ(tensor_ref.shape, BHWC(4, 5, 6, 7));
}

TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankLT0) {
  TfLiteTensor tflite_tensor;
  tflite_tensor.name = "tensor_name";
  tflite_tensor.type = TfLiteType::kTfLiteFloat32;
  tflite_tensor.dims = TfLiteIntArrayCreate(0);
  TensorRef<BHWC> tensor_ref;
  const auto status =
      ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
  TfLiteIntArrayFree(tflite_tensor.dims);
  // TODO(b/130054481): Cover scalar.
  EXPECT_FALSE(status.ok());
}

TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankGT3) {
  TfLiteTensor tflite_tensor;
  tflite_tensor.name = "tensor_name";
  tflite_tensor.type = TfLiteType::kTfLiteFloat32;
  tflite_tensor.dims = TfLiteIntArrayCreate(5);
  TensorRef<BHWC> tensor_ref;
  const auto status =
      ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
  TfLiteIntArrayFree(tflite_tensor.dims);
  EXPECT_FALSE(status.ok());
}

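// For quick reference, the dims -> BHWC mapping exercised by the tests above,
// as inferred from their expectations (not from the converter's
// implementation):
//
//   dims {b}          -> BHWC(b, 1, 1, 1)
//   dims {b, c}       -> BHWC(b, 1, 1, c)
//   dims {b, w, c}    -> BHWC(b, 1, w, c)
//   dims {b, h, w, c} -> BHWC(b, h, w, c)
//
// Empty dims (see the scalar TODO above) and dims with more than four entries
// yield a non-OK status.
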
class DelegatedInterpreter {
 public:
  explicit DelegatedInterpreter(int num_nodes) {
    exec_plan_ = TfLiteIntArrayCreate(num_nodes);
  }
  virtual ~DelegatedInterpreter() {
    TfLiteIntArrayFree(exec_plan_);
    for (auto params : delegate_params_) {
      TfLiteIntArrayFree(params.nodes_to_replace);
      TfLiteIntArrayFree(params.input_tensors);
      TfLiteIntArrayFree(params.output_tensors);
    }
  }

  // Get the TfLiteContext to be mocked for swapping out functions that have to
  // be called inside delegate (i.e. in delegate kernel mode).
  TfLiteContext* context() { return interpreter_.primary_subgraph().context(); }

  // node(int) and registration(int) are used to implement
  // GetNodeAndRegistration.  We can't implement those using
  //   TfLiteContext *context = interpreter_.primary_subgraph().context();
  //   context->GetNodeAndRegistration(context, &node, &registration);
  // here, because calling GetNodeAndRegistration from within its own
  // implementation would lead to an infinite loop.
  // Instead, we just call node_and_registration and use a const_cast.
  // These const_casts are a bit ugly, but I think less ugly than exposing
  // the private GetNodeAndRegistration method in Subgraph as public,
  // or making this class a friend of Subgraph.
  TfLiteNode* node(int index) {
    const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration =
        interpreter_.primary_subgraph().node_and_registration(index);
    return const_cast<TfLiteNode*>(&node_and_registration->first);
  }
  TfLiteRegistration* registration(int index) {
    const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration =
        interpreter_.primary_subgraph().node_and_registration(index);
    return const_cast<TfLiteRegistration*>(&node_and_registration->second);
  }

  TfLiteIntArray* exec_plan() const { return exec_plan_; }
  TfLiteDelegateParams* add_delegate_params() {
    delegate_params_.push_back(TfLiteDelegateParams());
    return &delegate_params_.back();
  }
  TfLiteDelegateParams* delegate_params() { return &delegate_params_.front(); }
  int num_delegate_params() { return delegate_params_.size(); }

 protected:
  Interpreter interpreter_;

 private:
  // The manually-set execution plan for this delegated interpreter.
  TfLiteIntArray* exec_plan_;

  // The TfLiteDelegateParams objects that are manually populated inside the
  // mocked TfLiteContext::PreviewDelegatePartitioning.
  std::vector<TfLiteDelegateParams> delegate_params_;
};

class InterpreterFp16 : public DelegatedInterpreter {
 public:
  explicit InterpreterFp16(TfLiteBuiltinOperator op,
                           bool const_dequantize_inputs = true)
      : DelegatedInterpreter(3) {
    void* builtin_data = malloc(sizeof(int));
    EXPECT_EQ(interpreter_.AddTensors(5), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetInputs({0, 1}), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetOutputs({4}), kTfLiteOk);

    // Add a Dequantize Node.
    const TfLiteRegistration reg_dequant0 = {
        nullptr, nullptr, nullptr, nullptr, nullptr, kTfLiteBuiltinDequantize};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{0}, /*outputs=*/{1}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_dequant0),
              kTfLiteOk);

    // Add a Dequantize Node.
    const TfLiteRegistration reg_dequant1 = {
        nullptr, nullptr, nullptr, nullptr, nullptr, kTfLiteBuiltinDequantize};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{2}, /*outputs=*/{3}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_dequant1),
              kTfLiteOk);

    // Add a node that GPU delegate can parse.
    const TfLiteRegistration reg_op0 = {
        [](TfLiteContext* context, const char* buffer, size_t length) {
          return reinterpret_cast<void*>(new int(1));
        },
        [](TfLiteContext* context, void* buffer) {
          delete reinterpret_cast<int*>(buffer);
        },
        nullptr,
        nullptr,
        nullptr,
        op};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{1, 3}, /*outputs=*/{4}, /*init_data=*/nullptr,
                  /*init_data_size=*/0,
                  /*builtin_data=*/builtin_data,
                  /*registration=*/&reg_op0),
              kTfLiteOk);

    // Set the inputs of the Dequantize nodes to the fp16 type, and their
    // outputs to the fp32 type.
    const std::vector<int> dims = {1};
    TfLiteQuantization quantization;
    quantization.type = kTfLiteNoQuantization;
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            0, TfLiteType::kTfLiteFloat16, "t0", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            2, TfLiteType::kTfLiteFloat16, "t2", dims, quantization, false),
        kTfLiteOk);
    if (const_dequantize_inputs) {
      // This simulates the DEQUANTIZE inputs being constants in the graph.
      // If this is not true, FP16GraphPartitionHelper should not consider the
      // corresponding DEQUANTIZE ops.
      auto* tensor0 = interpreter_.tensor(0);
      auto* tensor2 = interpreter_.tensor(2);
      tensor0->allocation_type = kTfLiteMmapRo;
      tensor2->allocation_type = kTfLiteMmapRo;
    }
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            1, TfLiteType::kTfLiteFloat32, "t1", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            3, TfLiteType::kTfLiteFloat32, "t3", dims, quantization, false),
        kTfLiteOk);

    exec_plan()->data[0] = 0;
    exec_plan()->data[1] = 1;
    exec_plan()->data[2] = 2;
  }
};

// **NOTE**: we have several interpreter instances created at global scope to
// test *exactly* the GetOpsToReplace function alone, and not the sequence of
// function calls that includes GetOpsToReplace when calling
// ModifyGraphWithDelegate. A TfLiteContext is needed to test GetOpsToReplace,
// but TfLiteContexts intentionally make it difficult to call certain functions
// in a non-delegate context (see tensorflow/lite/subgraph/subgraph.cc for
// details). We create our own GetExecutionPlan, GetNodeAndRegistration and
// PreviewDelegatePartitioning lambdas inside each test, but we can't use local
// captures without changing the function signature. Therefore, this test data
// lives at global scope in order to be accessible inside the lambda.

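// A minimal sketch of the resulting pattern, for illustration only (the
// `g_interpreter` name is hypothetical and not used below). TfLiteContext's
// hooks are C-style function pointers, so only capture-less lambdas convert
// to the required type, and those lambdas must reach the fixture through a
// global:
//
//   DelegatedInterpreter* g_interpreter = new InterpreterFp16(...);
//   ...
//   context->GetExecutionPlan = [](struct TfLiteContext* context,
//                                  TfLiteIntArray** execution_plan) {
//     *execution_plan = g_interpreter->exec_plan();  // a global, not a capture
//     return kTfLiteOk;
//   };
//   // A capturing lambda ([&] or [=]) would not convert to the function
//   // pointer type of TfLiteContext::GetExecutionPlan, so it can't be used.
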
InterpreterFp16* interpreter_fp16_add_op =
    new InterpreterFp16(kTfLiteBuiltinAdd);

TEST(ModelBuilderTest, GetOpsToReplaceAcceptsFp16DequantizeNodes) {
  // Before pruning, the graph has three nodes:
  //
  //   t0 (FP16) -> DequantNode -> t1 (FP32) -> Add -> t4
  //   t2 (FP16) -> DequantNode -> t3 (FP32) --/
  //
  // GetOpsToReplace should choose all three nodes for replacement, and
  // the graph on the GPU will look like this (no Dequants):
  //
  //   t0 (FP16) --> Add -> t4
  //   t2 (FP16) --/
  //
  TfLiteContext* context = interpreter_fp16_add_op->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter_fp16_add_op->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter_fp16_add_op->node(node_index);
    *registration = interpreter_fp16_add_op->registration(node_index);
    return kTfLiteOk;
  };
  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        // The partitioner should accept only the Add op initially.
        EXPECT_EQ(nodes_to_replace->size, 1);
        // Single partition output.
        auto params = interpreter_fp16_add_op->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 2;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 1;
        params->input_tensors->data[1] = 3;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 4;

        *partition_params_array = interpreter_fp16_add_op->delegate_params();
        *num_partitions = interpreter_fp16_add_op->num_delegate_params();
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);

  // The Dequant nodes are added to ops_to_replace as a post-processing step by
  // the FP16GraphPartitionHelper. ADD is delegated with its inputs pointing to
  // the FP16 inputs.
  EXPECT_EQ(ops_to_replace->size, 3);
  TfLiteNode* node = nullptr;
  TfLiteRegistration* registration = nullptr;
  context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
                                  &registration);
  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
            TfLiteType::kTfLiteFloat16);
  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
            TfLiteType::kTfLiteFloat16);
  TfLiteIntArrayFree(ops_to_replace);
}

InterpreterFp16* interpreter_fp16_non_constant =
    new InterpreterFp16(kTfLiteBuiltinAdd, /*const_dequantize_inputs=*/false);

// Same as GetOpsToReplaceAcceptsFp16DequantizeNodes, but the DEQUANTIZE inputs
// are not constant. As a result, we don't allow the delegate to accept them.
TEST(ModelBuilderTest, GetOpsToReplaceRejectsNonConstantFp16DequantizeNodes) {
  TfLiteContext* context = interpreter_fp16_non_constant->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter_fp16_non_constant->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter_fp16_non_constant->node(node_index);
    *registration = interpreter_fp16_non_constant->registration(node_index);
    return kTfLiteOk;
  };
  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        // The partitioner should accept only the Add op initially.
        EXPECT_EQ(nodes_to_replace->size, 1);
        // Single partition output.
        auto params = interpreter_fp16_non_constant->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 2;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 1;
        params->input_tensors->data[1] = 3;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 4;

        *partition_params_array =
            interpreter_fp16_non_constant->delegate_params();
        *num_partitions = interpreter_fp16_non_constant->num_delegate_params();
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);

  // Only ADD is delegated, with FP32 (dequantized) inputs.
  EXPECT_EQ(ops_to_replace->size, 1);
  TfLiteNode* node = nullptr;
  TfLiteRegistration* registration = nullptr;
  context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
                                  &registration);
  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
            TfLiteType::kTfLiteFloat32);
  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
            TfLiteType::kTfLiteFloat32);
  TfLiteIntArrayFree(ops_to_replace);
}

InterpreterFp16* interpreter_fp16_gt_op =
    new InterpreterFp16(kTfLiteBuiltinGreater);

TEST(ModelBuilderTest, GetOpsToReplaceRejectsFp16DequantizeNodes) {
  // Before pruning, the graph has three nodes:
  //
  //   t0 (FP16) -> DequantNode -> t1 (FP32) -> Greater Op -> t4
  //   t2 (FP16) -> DequantNode -> t3 (FP32) --/
  //
  // Because there is no GPU equivalent for the Greater op, we don't choose any
  // nodes.

  TfLiteContext* context = interpreter_fp16_gt_op->context();
  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter_fp16_gt_op->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter_fp16_gt_op->node(node_index);
    *registration = interpreter_fp16_gt_op->registration(node_index);
    return kTfLiteOk;
  };
  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        // No selected nodes.
        EXPECT_EQ(nodes_to_replace->size, 0);
        *partition_params_array = nullptr;
        *num_partitions = 0;
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);

  // No nodes were found to replace.
  EXPECT_EQ(ops_to_replace->size, 0);
  // Inputs to Greater op are still fp32.
  TfLiteNode* node = nullptr;
  TfLiteRegistration* registration = nullptr;
  const int kGreaterOpIndex = 2;
  context->GetNodeAndRegistration(context, kGreaterOpIndex, &node,
                                  &registration);
  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
            TfLiteType::kTfLiteFloat32);
  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
            TfLiteType::kTfLiteFloat32);
  TfLiteIntArrayFree(ops_to_replace);
}

class InterpreterFp32 : public DelegatedInterpreter {
 public:
  InterpreterFp32() : DelegatedInterpreter(2) {
    void* builtin_data = malloc(sizeof(int));
    EXPECT_EQ(interpreter_.AddTensors(4), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetInputs({0, 2}), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetOutputs({3}), kTfLiteOk);

    // Add a Dequantize Node with uint8 input.
    const TfLiteRegistration reg_dequant0 = {/*init=*/nullptr,
                                             /*free=*/nullptr,
                                             /*prepare=*/nullptr,
                                             /*invoke=*/nullptr,
                                             /*profiling_string=*/nullptr,
                                             kTfLiteBuiltinDequantize};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{0}, /*outputs=*/{1}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_dequant0),
              kTfLiteOk);

    // Add a node that GPU delegate can parse.
    const TfLiteRegistration reg_add0 = {
        [](TfLiteContext* context, const char* buffer, size_t length) {
          return reinterpret_cast<void*>(new int(1));
        },
        [](TfLiteContext* context, void* buffer) {
          delete reinterpret_cast<int*>(buffer);
        },
        nullptr,
        nullptr,
        nullptr,
        kTfLiteBuiltinAdd};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{1, 2}, /*outputs=*/{3}, /*init_data=*/nullptr,
                  /*init_data_size=*/0,
                  /*builtin_data=*/builtin_data,
                  /*registration=*/&reg_add0),
              kTfLiteOk);

    const std::vector<int> dims = {1};
    TfLiteQuantization quantization;
    quantization.type = kTfLiteNoQuantization;
    EXPECT_EQ(interpreter_.SetTensorParametersReadWrite(
                  0, TfLiteType::kTfLiteUInt8, "t0", dims, quantization, false),
              kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            1, TfLiteType::kTfLiteFloat32, "t1", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            2, TfLiteType::kTfLiteFloat32, "t2", dims, quantization, false),
        kTfLiteOk);

    exec_plan()->data[0] = 0;
    exec_plan()->data[1] = 1;
  }
};

InterpreterFp32* interpreter_fp32 = new InterpreterFp32();

TEST(ModelBuilderTest, GetOpsToReplaceDoesNotPruneUint8) {
  // A graph with a Dequant node with uint8 input is not pruned, as this op is
  // currently not supported on the GPU. The Dequant op will therefore be
  // scheduled to run on the CPU, while the remaining supported Add op runs on
  // the GPU.
  //
  //   t0 (uint8) --> Dequant --> t1 (FP32) --> Add -> t3
  //                              t2 (FP32) --/
  //
  TfLiteContext* context = interpreter_fp32->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter_fp32->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter_fp32->node(node_index);
    *registration = interpreter_fp32->registration(node_index);
    return kTfLiteOk;
  };
  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        auto params = interpreter_fp32->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 1;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 1;
        params->input_tensors->data[1] = 2;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 3;

        *partition_params_array = interpreter_fp32->delegate_params();
        *num_partitions = interpreter_fp32->num_delegate_params();
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);

  // As the Dequant op is not pruned and the ADD op could run on the GPU, we
  // have 1 partition.
  EXPECT_EQ(ops_to_replace->size, 1);
  // ADD at index 1.
  EXPECT_EQ(1, ops_to_replace->data[0]);

  TfLiteIntArrayFree(ops_to_replace);
}

class Interpreter2Fp32 : public DelegatedInterpreter {
 public:
  Interpreter2Fp32() : DelegatedInterpreter(4) {
    void* builtin_data = malloc(sizeof(int));
    EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetInputs({0, 2, 4, 6}), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetOutputs({7}), kTfLiteOk);

    // Add a Dequantize Node with uint8 input.
    const TfLiteRegistration reg_dequant = {/*init=*/nullptr,
                                            /*free=*/nullptr,
                                            /*prepare=*/nullptr,
                                            /*invoke=*/nullptr,
                                            /*profiling_string=*/nullptr,
                                            kTfLiteBuiltinDequantize};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{0}, /*outputs=*/{1}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_dequant),
              kTfLiteOk);

    // Add an ADD node that GPU delegate can parse.
    const TfLiteRegistration reg_add0 = {
        [](TfLiteContext* context, const char* buffer, size_t length) {
          return reinterpret_cast<void*>(new int(1));
        },
        [](TfLiteContext* context, void* buffer) {
          delete reinterpret_cast<int*>(buffer);
        },
        nullptr,
        nullptr,
        nullptr,
        kTfLiteBuiltinAdd};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{1, 2}, /*outputs=*/{3}, /*init_data=*/nullptr,
                  /*init_data_size=*/0,
                  /*builtin_data=*/builtin_data,
                  /*registration=*/&reg_add0),
              kTfLiteOk);

    // Add a Pack Node that GPU delegate doesn't support.
    const TfLiteRegistration reg_pack = {/*init=*/nullptr,
                                         /*free=*/nullptr,
                                         /*prepare=*/nullptr,
                                         /*invoke=*/nullptr,
                                         /*profiling_string=*/nullptr,
                                         kTfLiteBuiltinPack};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{3, 4}, /*outputs=*/{5}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_pack),
              kTfLiteOk);

    const TfLiteRegistration reg_add1 = {
        [](TfLiteContext* context, const char* buffer, size_t length) {
          return reinterpret_cast<void*>(new int[2]);
        },
        [](TfLiteContext* context, void* buffer) {
          delete[] reinterpret_cast<int*>(buffer);
        },
        nullptr,
        nullptr,
        nullptr,
        kTfLiteBuiltinAdd};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{5, 6}, /*outputs=*/{7}, /*init_data=*/nullptr,
                  /*init_data_size=*/0,
                  /*builtin_data=*/builtin_data,
                  /*registration=*/&reg_add1),
              kTfLiteOk);

    std::vector<int> dims = {1};
    TfLiteQuantization quantization;
    quantization.type = kTfLiteNoQuantization;
    EXPECT_EQ(interpreter_.SetTensorParametersReadWrite(
                  0, TfLiteType::kTfLiteUInt8, "t0", dims, quantization, false),
              kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            1, TfLiteType::kTfLiteFloat32, "t1", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            2, TfLiteType::kTfLiteFloat32, "t2", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            3, TfLiteType::kTfLiteFloat32, "t3", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            4, TfLiteType::kTfLiteFloat32, "t4", dims, quantization, false),
        kTfLiteOk);

    dims.push_back(2);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            5, TfLiteType::kTfLiteFloat32, "t5", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            6, TfLiteType::kTfLiteFloat32, "t6", dims, quantization, false),
        kTfLiteOk);

    exec_plan()->data[0] = 0;
    exec_plan()->data[1] = 1;
    exec_plan()->data[2] = 2;
    exec_plan()->data[3] = 3;
  }
};

Interpreter2Fp32* interpreter2_fp32 = new Interpreter2Fp32();

TEST(ModelBuilderTest, GetOpsToReplaceMultiplePartitions) {
  // A graph with a Dequant node with uint8 input and a Pack node is not
  // pruned. As these ops are currently not supported on the GPU, they will be
  // scheduled to run on the CPU, while the remaining supported Add ops run on
  // the GPU.
  //
  //   t0 (uint8) -> Dequant(0) -> t1 (FP32) -> Add(1) -> t3 (FP32) -> PACK (2)
  //                               t2 (FP32) -/           t4 (FP32) -/
  //   PACK (2) -> t5 (FP32) -> Add(3) -> t7
  //            -> t6 (FP32) -/
  //
  TfLiteContext* context = interpreter2_fp32->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter2_fp32->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter2_fp32->node(node_index);
    *registration = interpreter2_fp32->registration(node_index);
    return kTfLiteOk;
  };
  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        auto params = interpreter2_fp32->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 1;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 1;
        params->input_tensors->data[1] = 2;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 3;

        params = interpreter2_fp32->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 3;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 5;
        params->input_tensors->data[1] = 6;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 7;

        *partition_params_array = interpreter2_fp32->delegate_params();
        *num_partitions = interpreter2_fp32->num_delegate_params();
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(
      context, /*allow_quant_ops=*/false, /*max_delegated_partitions=*/2);

  // As the Dequant op is not pruned and the two ADD ops could run on the GPU,
  // we have 2 partitions.
  EXPECT_EQ(ops_to_replace->size, 2);
  // ADD at index 1.
  EXPECT_EQ(1, ops_to_replace->data[0]);
  // ADD at index 3.
  EXPECT_EQ(3, ops_to_replace->data[1]);

  TfLiteIntArrayFree(ops_to_replace);
}

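// A note on the max_delegated_partitions argument used above, inferred from
// this test's expectations: with the limit raised to 2, GetOpsToReplace keeps
// both single-node Add partitions rather than only the largest one, so the
// CPU-resident Dequant and Pack nodes simply split the delegated graph in two.
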
class InterpreterMultiNode : public DelegatedInterpreter {
 public:
  explicit InterpreterMultiNode(bool both_ops_supported = true)
      : DelegatedInterpreter(5) {
    void* builtin_data = malloc(sizeof(int));
    EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetInputs({0, 1, 2}), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetOutputs({6, 7}), kTfLiteOk);

    // Add 3 Dequantize Nodes with float16 input.
    for (int i = 0; i < 3; ++i) {
      const TfLiteRegistration reg_dequant = {/*init=*/nullptr,
                                              /*free=*/nullptr,
                                              /*prepare=*/nullptr,
                                              /*invoke=*/nullptr,
                                              /*profiling_string=*/nullptr,
                                              kTfLiteBuiltinDequantize};
      EXPECT_EQ(interpreter_.AddNodeWithParameters(
                    /*inputs=*/{i}, /*outputs=*/{i + 3}, /*init_data=*/nullptr,
                    /*init_data_size=*/0, /*builtin_data=*/nullptr,
                    /*registration=*/&reg_dequant),
                kTfLiteOk);
    }

    if (both_ops_supported) {
      // Add 2 ADD ops.
      const TfLiteRegistration reg_add0 = {
          [](TfLiteContext* context, const char* buffer, size_t length) {
            return reinterpret_cast<void*>(new int(1));
          },
          [](TfLiteContext* context, void* buffer) {
            delete reinterpret_cast<int*>(buffer);
          },
          nullptr,
          nullptr,
          nullptr,
          kTfLiteBuiltinAdd};
      EXPECT_EQ(interpreter_.AddNodeWithParameters(
                    /*inputs=*/{4, 5}, /*outputs=*/{7}, /*init_data=*/nullptr,
                    /*init_data_size=*/0,
                    /*builtin_data=*/builtin_data,
                    /*registration=*/&reg_add0),
                kTfLiteOk);

      const TfLiteRegistration reg_add1 = {
          [](TfLiteContext* context, const char* buffer, size_t length) {
            return reinterpret_cast<void*>(new int(1));
          },
          [](TfLiteContext* context, void* buffer) {
            delete reinterpret_cast<int*>(buffer);
          },
          nullptr,
          nullptr,
          nullptr,
          kTfLiteBuiltinAdd};
      EXPECT_EQ(interpreter_.AddNodeWithParameters(
                    /*inputs=*/{3, 4}, /*outputs=*/{6}, /*init_data=*/nullptr,
                    /*init_data_size=*/0,
                    /*builtin_data=*/builtin_data,
                    /*registration=*/&reg_add1),
                kTfLiteOk);
    } else {
      // Add the GREATER op node that GPU delegate doesn't support.
      const TfLiteRegistration reg_greater = {
          [](TfLiteContext* context, const char* buffer, size_t length) {
            return reinterpret_cast<void*>(new int(1));
          },
          [](TfLiteContext* context, void* buffer) {
            delete reinterpret_cast<int*>(buffer);
          },
          nullptr,
          nullptr,
          nullptr,
          kTfLiteBuiltinGreater};
      EXPECT_EQ(interpreter_.AddNodeWithParameters(
                    /*inputs=*/{3, 4}, /*outputs=*/{6}, /*init_data=*/nullptr,
                    /*init_data_size=*/0,
                    /*builtin_data=*/builtin_data,
                    /*registration=*/&reg_greater),
                kTfLiteOk);

      // Add the ADD op node that GPU delegate supports.
      const TfLiteRegistration reg_add0 = {
          [](TfLiteContext* context, const char* buffer, size_t length) {
            return reinterpret_cast<void*>(new int(1));
          },
          [](TfLiteContext* context, void* buffer) {
            delete reinterpret_cast<int*>(buffer);
          },
          nullptr,
          nullptr,
          nullptr,
          kTfLiteBuiltinAdd};
      EXPECT_EQ(interpreter_.AddNodeWithParameters(
                    /*inputs=*/{4, 5}, /*outputs=*/{7}, /*init_data=*/nullptr,
                    /*init_data_size=*/0,
                    /*builtin_data=*/builtin_data,
                    /*registration=*/&reg_add0),
                kTfLiteOk);
    }
    const std::vector<int> dims = {1};
    TfLiteQuantization quantization;
    quantization.type = kTfLiteNoQuantization;
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            0, TfLiteType::kTfLiteFloat16, "t0", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            1, TfLiteType::kTfLiteFloat16, "t1", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            2, TfLiteType::kTfLiteFloat16, "t2", dims, quantization, false),
        kTfLiteOk);
    // Simulate DEQUANTIZE inputs being constants.
    auto* tensor0 = interpreter_.tensor(0);
    auto* tensor1 = interpreter_.tensor(1);
    auto* tensor2 = interpreter_.tensor(2);
    tensor0->allocation_type = kTfLiteMmapRo;
    tensor1->allocation_type = kTfLiteMmapRo;
    tensor2->allocation_type = kTfLiteMmapRo;
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            3, TfLiteType::kTfLiteFloat32, "t3", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            4, TfLiteType::kTfLiteFloat32, "t4", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            5, TfLiteType::kTfLiteFloat32, "t5", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            6, TfLiteType::kTfLiteFloat32, "t6", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            7, TfLiteType::kTfLiteFloat32, "t7", dims, quantization, false),
        kTfLiteOk);

    exec_plan()->data[0] = 0;
    exec_plan()->data[1] = 1;
    exec_plan()->data[2] = 2;
    exec_plan()->data[3] = 3;
    exec_plan()->data[4] = 4;
  }
};

InterpreterMultiNode* interpreter_mn =
    new InterpreterMultiNode(/*both_ops_supported=*/false);

TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectFp16Nodes_SinglePartition) {
  // A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
  // 'Add' can be replaced by the GPU delegate, but 'Greater' cannot.
  //   t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(3) -> t6
  //   t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
  //                                          --\
  //   t2 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7
  //
  // GetOpsToReplace should accept 'Add' & the Dequant nodes that only output
  // to it (in this case, Dequant(2)).
  TfLiteContext* context = interpreter_mn->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter_mn->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter_mn->node(node_index);
    *registration = interpreter_mn->registration(node_index);
    return kTfLiteOk;
  };
  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        // The FP16GraphPartitionHelper should mark only the ADD op as
        // accepted.
        EXPECT_EQ(nodes_to_replace->size, 1);
        EXPECT_EQ(nodes_to_replace->data[0], 4);
        // Single partition.
        auto params = interpreter_mn->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 4;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 1;
        params->input_tensors->data[1] = 3;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 7;

        *partition_params_array = interpreter_mn->delegate_params();
        *num_partitions = interpreter_mn->num_delegate_params();
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);

  // Post-PreviewDelegatePartitioning, the partitioner will add Dequant(2) to
  // ops_to_replace, since it only outputs to a delegated node.
  EXPECT_EQ(ops_to_replace->size, 2);
  // Op at index 4 is the Add op.
  EXPECT_EQ(ops_to_replace->data[0], 4);
  EXPECT_EQ(ops_to_replace->data[1], 2);
  // Verify that the Add op has fp16 inputs.
  TfLiteNode* node = nullptr;
  TfLiteRegistration* registration = nullptr;
  context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
                                  &registration);
  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
            TfLiteType::kTfLiteFloat16);
  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
            TfLiteType::kTfLiteFloat16);
  TfLiteIntArrayFree(ops_to_replace);
}

InterpreterMultiNode* interpreter_mn2 =
    new InterpreterMultiNode(/*both_ops_supported=*/true);
TEST(ModelBuilderTest,
     GetOpsToReplaceSelectsCorrectFp16Nodes_MultiplePartitions) {
  // A graph with three Dequant nodes feeding two Add ops.
  //   t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Add(3) -> t6
  //   t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
  //                                          --\
  //   t2 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7
  //
  // In this test case, we purposely partition Add(3) & Add(4) into different
  // partitions, to check if Dequant nodes that output *only* to the first
  // partition's nodes are accepted.

  TfLiteContext* context = interpreter_mn2->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter_mn2->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter_mn2->node(node_index);
    *registration = interpreter_mn2->registration(node_index);
    return kTfLiteOk;
  };

  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        // The FP16GraphPartitionHelper should mark both ADD ops as accepted.
        EXPECT_EQ(nodes_to_replace->size, 2);
        EXPECT_EQ(nodes_to_replace->data[0], 3);
        EXPECT_EQ(nodes_to_replace->data[1], 4);
        // Technically, both ADD ops should end up in the same partition.
        // But we put them in different partitions to test post-processing with
        // DEQUANTIZE nodes.
        // First partition with Add(3).
        auto params = interpreter_mn2->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 3;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 3;
        params->input_tensors->data[1] = 4;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 6;
        // Second partition with Add(4).
        params = interpreter_mn2->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(1);
        params->nodes_to_replace->data[0] = 4;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 4;
        params->input_tensors->data[1] = 5;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 7;

        *partition_params_array = interpreter_mn2->delegate_params();
        *num_partitions = interpreter_mn2->num_delegate_params();
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(
      context, /*allow_quant_ops=*/false, /*max_delegated_partitions=*/2);

  // Three ops should be selected:
  // Add(3), Dequant(x), Add(4)
  // Since both partitions are of size 1, either could end up as the 'first'
  // partition with one Dequant node added for it.
  EXPECT_EQ(ops_to_replace->size, 3);

  TfLiteNode* node = nullptr;
  TfLiteRegistration* registration = nullptr;
  // Verify that both Add ops have fp16 inputs.
  context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
                                  &registration);
  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
            TfLiteType::kTfLiteFloat16);
  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
            TfLiteType::kTfLiteFloat16);
  context->GetNodeAndRegistration(context, ops_to_replace->data[2], &node,
                                  &registration);
  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
            TfLiteType::kTfLiteFloat16);
  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
            TfLiteType::kTfLiteFloat16);
  // Verify that the op at index 1 is a Dequant outputting to a single Add.
  EXPECT_TRUE(ops_to_replace->data[1] == 0 || ops_to_replace->data[1] == 2);
  TfLiteIntArrayFree(ops_to_replace);
}

// Adds the pattern:
//
// float -> QUANTIZE -> ADD -> DEQUANTIZE -> float
// float -> QUANTIZE ----^
//
// The tensors between the QUANTIZE & DEQUANTIZE nodes are int8.
class InterpreterQuantized : public DelegatedInterpreter {
 public:
  InterpreterQuantized() : DelegatedInterpreter(4) {
    void* builtin_data = malloc(sizeof(int));
    EXPECT_EQ(interpreter_.AddTensors(6), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetInputs({0, 3}), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetOutputs({5}), kTfLiteOk);

    // QUANTIZE 1
    const TfLiteRegistration reg_quant0 = {/*init=*/nullptr,
                                           /*free=*/nullptr,
                                           /*prepare=*/nullptr,
                                           /*invoke=*/nullptr,
                                           /*profiling_string=*/nullptr,
                                           kTfLiteBuiltinQuantize};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{0}, /*outputs=*/{1}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_quant0),
              kTfLiteOk);

    // QUANTIZE 2
    const TfLiteRegistration reg_quant1 = {/*init=*/nullptr,
                                           /*free=*/nullptr,
                                           /*prepare=*/nullptr,
                                           /*invoke=*/nullptr,
                                           /*profiling_string=*/nullptr,
                                           kTfLiteBuiltinQuantize};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{3}, /*outputs=*/{2}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_quant1),
              kTfLiteOk);

    // ADD
    const TfLiteRegistration reg_add0 = {
        [](TfLiteContext* context, const char* buffer, size_t length) {
          return reinterpret_cast<void*>(new int(1));
        },
        [](TfLiteContext* context, void* buffer) {
          delete reinterpret_cast<int*>(buffer);
        },
        nullptr,
        nullptr,
        nullptr,
        kTfLiteBuiltinAdd};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{1, 2}, /*outputs=*/{4}, /*init_data=*/nullptr,
                  /*init_data_size=*/0,
                  /*builtin_data=*/builtin_data,
                  /*registration=*/&reg_add0),
              kTfLiteOk);

    // DEQUANTIZE
    const TfLiteRegistration reg_dequant0 = {/*init=*/nullptr,
                                             /*free=*/nullptr,
                                             /*prepare=*/nullptr,
                                             /*invoke=*/nullptr,
                                             /*profiling_string=*/nullptr,
                                             kTfLiteBuiltinDequantize};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{4}, /*outputs=*/{5}, /*init_data=*/nullptr,
                  /*init_data_size=*/0, /*builtin_data=*/nullptr,
                  /*registration=*/&reg_dequant0),
              kTfLiteOk);

    const std::vector<int> dims = {1, 3, 3, 2};

    // Input & output tensors are floating-point.
    TfLiteQuantization no_quantization;
    no_quantization.type = kTfLiteNoQuantization;
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            0, TfLiteType::kTfLiteFloat32, "t0", dims, no_quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            3, TfLiteType::kTfLiteFloat32, "t3", dims, no_quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            5, TfLiteType::kTfLiteFloat32, "t5", dims, no_quantization, false),
        kTfLiteOk);
    // Other tensors are int8.
    float scale = 0.5f;
    int32_t zero_point = 12;
    TfLiteQuantization rw_quantization;
    rw_quantization.type = kTfLiteAffineQuantization;
    auto* rw_affine_quantization = static_cast<TfLiteAffineQuantization*>(
        malloc(sizeof(TfLiteAffineQuantization)));
    rw_affine_quantization->scale = TfLiteFloatArrayCreate(1);
    rw_affine_quantization->zero_point = TfLiteIntArrayCreate(1);
    rw_affine_quantization->scale->data[0] = scale;
    rw_affine_quantization->zero_point->data[0] = zero_point;
    rw_quantization.params = rw_affine_quantization;
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            1, TfLiteType::kTfLiteInt8, "t1", dims, rw_quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            2, TfLiteType::kTfLiteInt8, "t2", dims, rw_quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            4, TfLiteType::kTfLiteInt8, "t4", dims, rw_quantization, false),
        kTfLiteOk);

    exec_plan()->data[0] = 0;
    exec_plan()->data[1] = 1;
    exec_plan()->data[2] = 2;
    exec_plan()->data[3] = 3;
  }
};

InterpreterQuantized* interpreter_quant = new InterpreterQuantized();
TEST(ModelBuilderTest, GetOpsToReplace_AllowQuantOps) {
  TfLiteContext* context = interpreter_quant->context();

  // These functions are meant to be called inside delegates. Swap out
  // for similar functions to permit direct calling of GetOpsToReplace.
  context->GetExecutionPlan = [](struct TfLiteContext* context,
                                 TfLiteIntArray** execution_plan) {
    *execution_plan = interpreter_quant->exec_plan();
    return kTfLiteOk;
  };
  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                       TfLiteNode** node,
                                       TfLiteRegistration** registration) {
    *node = interpreter_quant->node(node_index);
    *registration = interpreter_quant->registration(node_index);
    return kTfLiteOk;
  };
  context->PreviewDelegatePartitioning =
      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
        if (nodes_to_replace->size == 0) {
          *num_partitions = 0;
          return kTfLiteOk;
        }
        auto params = interpreter_quant->add_delegate_params();
        params->nodes_to_replace = TfLiteIntArrayCreate(3);
        params->nodes_to_replace->data[0] = 0;
        params->nodes_to_replace->data[1] = 1;
        params->nodes_to_replace->data[2] = 2;
        params->input_tensors = TfLiteIntArrayCreate(2);
        params->input_tensors->data[0] = 0;
        params->input_tensors->data[1] = 3;
        params->output_tensors = TfLiteIntArrayCreate(1);
        params->output_tensors->data[0] = 4;

        *partition_params_array = interpreter_quant->delegate_params();
        *num_partitions = interpreter_quant->num_delegate_params();
        return kTfLiteOk;
      };

  TfLiteIntArray* ops_to_replace =
      GetOpsToReplace(context, /*allow_quant_ops=*/true);
  // If we allow quant ops, the two QUANTIZE nodes and the ADD node should be
  // accepted.
  EXPECT_EQ(ops_to_replace->size, 3);
  EXPECT_EQ(0, ops_to_replace->data[0]);
  EXPECT_EQ(1, ops_to_replace->data[1]);
  EXPECT_EQ(2, ops_to_replace->data[2]);

  TfLiteIntArray* ops_to_replace_without_quant =
      GetOpsToReplace(context, /*allow_quant_ops=*/false);
  // No ops should be accepted.
  EXPECT_EQ(ops_to_replace_without_quant->size, 0);

  TfLiteIntArrayFree(ops_to_replace);
  TfLiteIntArrayFree(ops_to_replace_without_quant);
}

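// For reference, the GetOpsToReplace overload exercised throughout this file,
// as inferred from the call sites above (see model_builder.h for the
// authoritative declaration; the default arguments shown here are
// assumptions):
//
//   TfLiteIntArray* GetOpsToReplace(TfLiteContext* context,
//                                   bool allow_quant_ops = false,
//                                   int max_delegated_partitions = 1);
//
// Each call site above owns the returned array and releases it with
// TfLiteIntArrayFree.
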
}  // namespace
}  // namespace gpu
}  // namespace tflite