/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/delegate_test_util.h"

#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <memory>
#include <string>
#include <vector>

#include <gtest/gtest.h>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/util.h"

namespace tflite {
namespace delegates {
namespace test_utils {

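// Registration for a custom "my_add" op: Prepare() verifies that both inputs
// have the same shape and sizes the output to match; Invoke() writes the
// element-wise sum of the two inputs to the output.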
TfLiteRegistration AddOpRegistration() {
  TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};

  reg.custom_name = "my_add";
  reg.builtin_code = tflite::BuiltinOperator_CUSTOM;

  reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
    const TfLiteTensor* input1;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input1));
    const TfLiteTensor* input2;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input2));
    TfLiteTensor* output;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

    // Verify that the two inputs have the same shape.
    TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size);
    for (int i = 0; i < input1->dims->size; ++i) {
      TF_LITE_ENSURE_EQ(context, input1->dims->data[i], input2->dims->data[i]);
    }

    // Set output shape to match input shape.
    TF_LITE_ENSURE_STATUS(context->ResizeTensor(
        context, output, TfLiteIntArrayCopy(input1->dims)));
    return kTfLiteOk;
  };

  reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
    const TfLiteTensor* a0;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &a0));
    TF_LITE_ENSURE(context, a0);
    TF_LITE_ENSURE(context, a0->data.f);
    const TfLiteTensor* a1;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &a1));
    TF_LITE_ENSURE(context, a1);
    TF_LITE_ENSURE(context, a1->data.f);
    TfLiteTensor* out;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out));
    TF_LITE_ENSURE(context, out);
    TF_LITE_ENSURE(context, out->data.f);
    // Set output data to element-wise sum of input data.
    int num = a0->dims->data[0];
    for (int i = 0; i < num; i++) {
      out->data.f[i] = a0->data.f[i] + a1->data.f[i];
    }
    return kTfLiteOk;
  };
  return reg;
}

void TestDelegate::SetUp() {
  interpreter_.reset(new Interpreter);
  SetUpSubgraph(&interpreter_->primary_subgraph());
}

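// Builds a graph of three "my_add" nodes over five float tensors:
//   node 0: tensor2 = tensor0 + tensor0
//   node 1: tensor3 = tensor1 + tensor1
//   node 2: tensor4 = tensor2 + tensor1
// Graph inputs are {0, 1} and outputs are {3, 4}.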
void TestDelegate::SetUpSubgraph(Subgraph* subgraph) {
  subgraph->AddTensors(5);
  subgraph->SetInputs({0, 1});
  subgraph->SetOutputs({3, 4});
  std::vector<int> dims({3});
  TfLiteQuantization quant{kTfLiteNoQuantization, nullptr};
  subgraph->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", dims.size(),
                                         dims.data(), quant, false);
  subgraph->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", dims.size(),
                                         dims.data(), quant, false);
  subgraph->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", dims.size(),
                                         dims.data(), quant, false);
  subgraph->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", dims.size(),
                                         dims.data(), quant, false);
  subgraph->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", dims.size(),
                                         dims.data(), quant, false);
  TfLiteRegistration reg = AddOpRegistration();
  int node_index_ignored;
  subgraph->AddNodeWithParameters({0, 0}, {2}, {}, nullptr, 0, nullptr, &reg,
                                  &node_index_ignored);
  subgraph->AddNodeWithParameters({1, 1}, {3}, {}, nullptr, 0, nullptr, &reg,
                                  &node_index_ignored);
  subgraph->AddNodeWithParameters({2, 1}, {4}, {}, nullptr, 0, nullptr, &reg,
                                  &node_index_ignored);
}

void TestDelegate::TearDown() {
  // The interpreter relies on the delegate to free its resources properly, so
  // the delegate must outlive the interpreter: reset the interpreter first.
  interpreter_.reset();
  delegate_.reset();
}

TestDelegate::SimpleDelegate::SimpleDelegate(
    const std::vector<int>& nodes, int64_t delegate_flags,
    bool fail_node_prepare, int min_ops_per_subset, bool fail_node_invoke,
    bool automatic_shape_propagation, bool custom_op)
    : nodes_(nodes),
      fail_delegate_node_prepare_(fail_node_prepare),
      min_ops_per_subset_(min_ops_per_subset),
      fail_delegate_node_invoke_(fail_node_invoke),
      automatic_shape_propagation_(automatic_shape_propagation),
      custom_op_(custom_op) {
  delegate_.Prepare = [](TfLiteContext* context,
                         TfLiteDelegate* delegate) -> TfLiteStatus {
    auto* simple = static_cast<SimpleDelegate*>(delegate->data_);
    TfLiteIntArray* nodes_to_separate =
        TfLiteIntArrayCreate(simple->nodes_.size());
    // Mark nodes that we want in TfLiteIntArray* structure.
    int index = 0;
    for (auto node_index : simple->nodes_) {
      nodes_to_separate->data[index++] = node_index;
      // make sure node is added
      TfLiteNode* node;
      TfLiteRegistration* reg;
      context->GetNodeAndRegistration(context, node_index, &node, &reg);
      if (simple->custom_op_) {
        TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
        TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
      } else {
        TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_ADD);
      }
    }
    // Check that all nodes are available
    TfLiteIntArray* execution_plan;
    TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
    for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
      int node_index = execution_plan->data[exec_index];
      TfLiteNode* node;
      TfLiteRegistration* reg;
      context->GetNodeAndRegistration(context, node_index, &node, &reg);
      if (exec_index == node_index) {
        // Check op details only if it wasn't delegated already.
        if (simple->custom_op_) {
          TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
          TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
        } else {
          TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_ADD);
        }
      }
    }

    // Get preview of delegate partitioning from the context.
    TfLiteDelegateParams* params_array;
    int num_partitions;
    TFLITE_CHECK_EQ(
        context->PreviewDelegatePartitioning(context, nodes_to_separate,
                                             &params_array, &num_partitions),
        kTfLiteOk);

    if (simple->min_ops_per_subset() > 0) {
      // Build a new vector of ops from subsets with at least the minimum
      // size.
      std::vector<int> allowed_ops;
      for (int idx = 0; idx < num_partitions; ++idx) {
        const auto* nodes_in_subset = params_array[idx].nodes_to_replace;
        if (nodes_in_subset->size < simple->min_ops_per_subset()) continue;
        allowed_ops.insert(allowed_ops.end(), nodes_in_subset->data,
                           nodes_in_subset->data + nodes_in_subset->size);
      }

      // Free existing nodes_to_separate & initialize a new array with
      // allowed_ops.
      TfLiteIntArrayFree(nodes_to_separate);
      nodes_to_separate = TfLiteIntArrayCreate(allowed_ops.size());
      memcpy(nodes_to_separate->data, allowed_ops.data(),
             sizeof(int) * nodes_to_separate->size);
    }

    // Another call to PreviewDelegatePartitioning should be okay, since
    // partitioning memory is managed by context.
    TFLITE_CHECK_EQ(
        context->PreviewDelegatePartitioning(context, nodes_to_separate,
                                             &params_array, &num_partitions),
        kTfLiteOk);

    context->ReplaceNodeSubsetsWithDelegateKernels(
        context, simple->FakeFusedRegistration(), nodes_to_separate, delegate);
    TfLiteIntArrayFree(nodes_to_separate);
    return kTfLiteOk;
  };
  delegate_.CopyToBufferHandle = [](TfLiteContext* context,
                                    TfLiteDelegate* delegate,
                                    TfLiteBufferHandle buffer_handle,
                                    TfLiteTensor* tensor) -> TfLiteStatus {
    // TODO(b/156586986): Implement tests to test buffer copying logic.
    return kTfLiteOk;
  };
  delegate_.CopyFromBufferHandle = [](TfLiteContext* context,
                                      TfLiteDelegate* delegate,
                                      TfLiteBufferHandle buffer_handle,
                                      TfLiteTensor* output) -> TfLiteStatus {
    TFLITE_CHECK_GE(buffer_handle, -1);
    TFLITE_CHECK_EQ(output->buffer_handle, buffer_handle);
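    // Stand in for a real buffer read by filling the output with the constant
    // 6.0f, so tests can detect that this path was taken.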
    const float floats[] = {6., 6., 6.};
    int num = output->dims->data[0];
    for (int i = 0; i < num; i++) {
      output->data.f[i] = floats[i];
    }
    return kTfLiteOk;
  };

  delegate_.FreeBufferHandle =
      [](TfLiteContext* context, TfLiteDelegate* delegate,
         TfLiteBufferHandle* handle) { *handle = kTfLiteNullBufferHandle; };
  // Store a type-punned pointer to this SimpleDelegate structure.
  delegate_.data_ = static_cast<void*>(this);
  delegate_.flags = delegate_flags;
}

TfLiteRegistration TestDelegate::SimpleDelegate::FakeFusedRegistration() {
  TfLiteRegistration reg = {nullptr};
  reg.custom_name = "fake_fused_op";

  // Different flavors of the delegate kernel's Invoke(), dependent on
  // testing parameters.
  if (fail_delegate_node_invoke_) {
    reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
      return kTfLiteError;
    };
  } else {
    reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
      // Compute the element-wise sum of the inputs (a single input is added to
      // itself).
      const TfLiteTensor* a0;
      const TfLiteTensor* a1;
      if (node->inputs->size == 2) {
        a0 = GetInput(context, node, 0);
        a1 = GetInput(context, node, 1);
      } else {
        a0 = GetInput(context, node, 0);
        a1 = a0;
      }
      TfLiteTensor* out;
      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out));
      int num = 1;
      for (int i = 0; i < a0->dims->size; ++i) {
        num *= a0->dims->data[i];
      }
      for (int i = 0; i < num; i++) {
        out->data.f[i] = a0->data.f[i] + a1->data.f[i];
      }
      if (out->buffer_handle != kTfLiteNullBufferHandle) {
        // Make the data stale so that CopyFromBufferHandle can be invoked.
        out->data_is_stale = true;
      }
      return kTfLiteOk;
    };
  }

  // Different flavors of the delegate kernel's Prepare(), dependent on
  // testing parameters.
  if (automatic_shape_propagation_) {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      // Shapes should already be propagated by the runtime; just check them.
      const TfLiteTensor* input1;
      TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input1));
      TfLiteTensor* output;
      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
      const int input_dims_size = input1->dims->size;
      TF_LITE_ENSURE(context, output->dims->size == input_dims_size);
      for (int i = 0; i < input_dims_size; ++i) {
        TF_LITE_ENSURE(context, output->dims->data[i] == input1->dims->data[i]);
      }
      return kTfLiteOk;
    };
  } else if (fail_delegate_node_prepare_) {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      return kTfLiteError;
    };
  } else {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      // Set output size to input size.
      const TfLiteTensor* input1;
      const TfLiteTensor* input2;
      if (node->inputs->size == 2) {
        input1 = GetInput(context, node, 0);
        input2 = GetInput(context, node, 1);
      } else {
        input1 = GetInput(context, node, 0);
        input2 = input1;
      }
      TfLiteTensor* output;
      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

      TF_LITE_ENSURE_STATUS(context->ResizeTensor(
          context, output, TfLiteIntArrayCopy(input1->dims)));
      return kTfLiteOk;
    };
  }

  return reg;
}

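// Builds a model of four Dequantize nodes (each feeding an fp16 constant with
// value 2.0) interleaved with Add, Add, Mul, Add, i.e.
// output = ((input + 2) + 2) * 2 + 2.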
void TestFP16Delegation::SetUp() {
  interpreter_.reset(new Interpreter);
  interpreter_->AddTensors(13);
  interpreter_->SetInputs({0});
  interpreter_->SetOutputs({12});

  float16_const_ = Eigen::half_impl::float_to_half_rtne(2.f);

  // TENSORS.
  TfLiteQuantizationParams quant;
  // Input.
  interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {1}, quant);
  // fp16 constant, dequantize output, Add0 output.
  interpreter_->SetTensorParametersReadOnly(
      1, kTfLiteFloat16, "", {1}, quant,
      reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
  interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {1}, quant);
  interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {1}, quant);
  // fp16 constant, dequantize output, Add1 output.
  interpreter_->SetTensorParametersReadOnly(
      4, kTfLiteFloat16, "", {1}, quant,
      reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
  interpreter_->SetTensorParametersReadWrite(5, kTfLiteFloat32, "", {1}, quant);
  interpreter_->SetTensorParametersReadWrite(6, kTfLiteFloat32, "", {1}, quant);
  // fp16 constant, dequantize output, Mul0 output.
  interpreter_->SetTensorParametersReadOnly(
      7, kTfLiteFloat16, "", {1}, quant,
      reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
  interpreter_->SetTensorParametersReadWrite(8, kTfLiteFloat32, "", {1}, quant);
  interpreter_->SetTensorParametersReadWrite(9, kTfLiteFloat32, "", {1}, quant);
  // fp16 constant, dequantize output, Add2 output.
  interpreter_->SetTensorParametersReadOnly(
      10, kTfLiteFloat16, "", {1}, quant,
      reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
  interpreter_->SetTensorParametersReadWrite(11, kTfLiteFloat32, "", {1},
                                             quant);
  interpreter_->SetTensorParametersReadWrite(12, kTfLiteFloat32, "", {1},
                                             quant);

  // NODES.
  auto* add_reg = ops::builtin::Register_ADD();
  auto* mul_reg = ops::builtin::Register_MUL();
  auto* deq_reg = ops::builtin::Register_DEQUANTIZE();
  add_reg->builtin_code = kTfLiteBuiltinAdd;
  deq_reg->builtin_code = kTfLiteBuiltinDequantize;
  mul_reg->builtin_code = kTfLiteBuiltinMul;
  TfLiteAddParams* builtin_data0 =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  TfLiteAddParams* builtin_data1 =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  TfLiteMulParams* builtin_data2 =
      reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams)));
  TfLiteAddParams* builtin_data3 =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  builtin_data0->activation = kTfLiteActNone;
  builtin_data1->activation = kTfLiteActNone;
  builtin_data2->activation = kTfLiteActNone;
  builtin_data3->activation = kTfLiteActNone;
  interpreter_->AddNodeWithParameters({1}, {2}, nullptr, 0, nullptr, deq_reg);
  interpreter_->AddNodeWithParameters({0, 2}, {3}, nullptr, 0, builtin_data0,
                                      add_reg);
  interpreter_->AddNodeWithParameters({4}, {5}, nullptr, 0, nullptr, deq_reg);
  interpreter_->AddNodeWithParameters({3, 5}, {6}, nullptr, 0, builtin_data1,
                                      add_reg);
  interpreter_->AddNodeWithParameters({7}, {8}, nullptr, 0, nullptr, deq_reg);
  interpreter_->AddNodeWithParameters({6, 8}, {9}, nullptr, 0, builtin_data2,
                                      mul_reg);
  interpreter_->AddNodeWithParameters({10}, {11}, nullptr, 0, nullptr, deq_reg);
  interpreter_->AddNodeWithParameters({9, 11}, {12}, nullptr, 0, builtin_data3,
                                      add_reg);
}

void TestFP16Delegation::VerifyInvoke() {
  std::vector<float> input = {3.0f};
  std::vector<float> expected_output = {16.0f};
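  // With input 3.0 and every fp16 constant equal to 2.0, the graph computes
  // ((3 + 2) + 2) * 2 + 2 = 16.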

  const int input_tensor_idx = interpreter_->inputs()[0];
  const int output_tensor_idx = interpreter_->outputs()[0];

  memcpy(interpreter_->typed_tensor<float>(input_tensor_idx), input.data(),
         sizeof(float));
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
  TfLiteTensor* output_tensor = interpreter_->tensor(output_tensor_idx);
  for (int i = 0; i < 1; ++i) {
    EXPECT_EQ(output_tensor->data.f[i], expected_output[i]) << i;
  }
}

TestFP16Delegation::FP16Delegate::FP16Delegate(int num_delegated_subsets,
                                               bool fail_node_prepare,
                                               bool fail_node_invoke)
    : num_delegated_subsets_(num_delegated_subsets),
      fail_delegate_node_prepare_(fail_node_prepare),
      fail_delegate_node_invoke_(fail_node_invoke) {
  delegate_.Prepare = [](TfLiteContext* context,
                         TfLiteDelegate* delegate) -> TfLiteStatus {
    auto* fp16_delegate = static_cast<FP16Delegate*>(delegate->data_);
    // FP16 graph partitioning.
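    // Only builtin ADD nodes are reported as supported, so MUL nodes are left
    // to the CPU.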
    delegates::IsNodeSupportedFn node_supported_fn =
        [=](TfLiteContext* context, TfLiteNode* node,
            TfLiteRegistration* registration,
            std::string* unsupported_details) -> bool {
      return registration->builtin_code == kTfLiteBuiltinAdd;
    };
    delegates::FP16GraphPartitionHelper partition_helper(context,
                                                         node_supported_fn);
    TfLiteIntArray* nodes_to_separate = nullptr;
    if (partition_helper.Partition(nullptr) != kTfLiteOk) {
      nodes_to_separate = TfLiteIntArrayCreate(0);
    } else {
      std::vector<int> ops_to_replace =
          partition_helper.GetNodesOfFirstNLargestPartitions(
              fp16_delegate->num_delegated_subsets());
      nodes_to_separate = ConvertVectorToTfLiteIntArray(ops_to_replace);
    }

    context->ReplaceNodeSubsetsWithDelegateKernels(
        context, fp16_delegate->FakeFusedRegistration(), nodes_to_separate,
        delegate);
    TfLiteIntArrayFree(nodes_to_separate);
    return kTfLiteOk;
  };
  delegate_.CopyFromBufferHandle =
      [](TfLiteContext* context, TfLiteDelegate* delegate,
         TfLiteBufferHandle buffer_handle,
         TfLiteTensor* output) -> TfLiteStatus { return kTfLiteOk; };
  delegate_.FreeBufferHandle = nullptr;
  delegate_.CopyToBufferHandle = nullptr;
  // Store a type-punned pointer to this FP16Delegate structure.
  delegate_.data_ = static_cast<void*>(this);
  delegate_.flags = kTfLiteDelegateFlagsNone;
}

TfLiteRegistration TestFP16Delegation::FP16Delegate::FakeFusedRegistration() {
  TfLiteRegistration reg = {nullptr};
  reg.custom_name = "fake_fp16_add_op";

  // Different flavors of the delegate kernel's Invoke(), dependent on
  // testing parameters.
  if (fail_delegate_node_invoke_) {
    reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
      return kTfLiteError;
    };
  } else {
    reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
      float output = 0;
      for (int i = 0; i < node->inputs->size; ++i) {
        const TfLiteTensor* input_tensor = GetInput(context, node, i);
        if (input_tensor->type == kTfLiteFloat32) {
          output += input_tensor->data.f[0];
        } else {
          // All constants are 2.
          output += 2;
        }
      }
      TfLiteTensor* out = GetOutput(context, node, 0);
      out->data.f[0] = output;
      return kTfLiteOk;
    };
  }

  // Different flavors of the delegate kernel's Prepare(), dependent on
  // testing parameters.
  if (fail_delegate_node_prepare_) {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      return kTfLiteError;
    };
  } else {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      // Set output size to input size
      const TfLiteTensor* input = GetInput(context, node, 0);
      TfLiteTensor* output = GetOutput(context, node, 0);
      TF_LITE_ENSURE_STATUS(context->ResizeTensor(
          context, output, TfLiteIntArrayCopy(input->dims)));
      return kTfLiteOk;
    };
  }

  return reg;
}

}  // namespace test_utils
}  // namespace delegates
}  // namespace tflite