1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/delegate_test_util.h"
17
18 #include <stdint.h>
19 #include <stdlib.h>
20 #include <string.h>
21
22 #include <memory>
23 #include <string>
24 #include <vector>
25
26 #include <gtest/gtest.h>
27 #include "third_party/eigen3/Eigen/Core"
28 #include "tensorflow/lite/builtin_ops.h"
29 #include "tensorflow/lite/c/builtin_op_data.h"
30 #include "tensorflow/lite/delegates/utils.h"
31 #include "tensorflow/lite/interpreter.h"
32 #include "tensorflow/lite/kernels/builtin_op_kernels.h"
33 #include "tensorflow/lite/kernels/internal/compatibility.h"
34 #include "tensorflow/lite/kernels/kernel_util.h"
35 #include "tensorflow/lite/schema/schema_generated.h"
36 #include "tensorflow/lite/string_type.h"
37 #include "tensorflow/lite/util.h"
38
39 namespace tflite {
40 namespace delegates {
41 namespace test_utils {
42
AddOpRegistration()43 TfLiteRegistration AddOpRegistration() {
44 TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
45
46 reg.custom_name = "my_add";
47 reg.builtin_code = tflite::BuiltinOperator_CUSTOM;
48
49 reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
50 const TfLiteTensor* input1;
51 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input1));
52 const TfLiteTensor* input2;
53 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input2));
54 TfLiteTensor* output;
55 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
56
57 // Verify that the two inputs have the same shape.
58 TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size);
59 for (int i = 0; i < input1->dims->size; ++i) {
60 TF_LITE_ENSURE_EQ(context, input1->dims->data[i], input2->dims->data[i]);
61 }
62
63 // Set output shape to match input shape.
64 TF_LITE_ENSURE_STATUS(context->ResizeTensor(
65 context, output, TfLiteIntArrayCopy(input1->dims)));
66 return kTfLiteOk;
67 };
68
69 reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
70 const TfLiteTensor* a0;
71 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &a0));
72 TF_LITE_ENSURE(context, a0);
73 TF_LITE_ENSURE(context, a0->data.f);
74 const TfLiteTensor* a1;
75 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &a1));
76 TF_LITE_ENSURE(context, a1);
77 TF_LITE_ENSURE(context, a1->data.f);
78 TfLiteTensor* out;
79 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out));
80 TF_LITE_ENSURE(context, out);
81 TF_LITE_ENSURE(context, out->data.f);
82 // Set output data to element-wise sum of input data.
83 int num = a0->dims->data[0];
84 for (int i = 0; i < num; i++) {
85 out->data.f[i] = a0->data.f[i] + a1->data.f[i];
86 }
87 return kTfLiteOk;
88 };
89 return reg;
90 }
91
SetUp()92 void TestDelegate::SetUp() {
93 interpreter_.reset(new Interpreter);
94 SetUpSubgraph(&interpreter_->primary_subgraph());
95 }
96
SetUpSubgraph(Subgraph * subgraph)97 void TestDelegate::SetUpSubgraph(Subgraph* subgraph) {
98 subgraph->AddTensors(5);
99 subgraph->SetInputs({0, 1});
100 subgraph->SetOutputs({3, 4});
101 std::vector<int> dims({3});
102 TfLiteQuantization quant{kTfLiteNoQuantization, nullptr};
103 subgraph->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", dims.size(),
104 dims.data(), quant, false);
105 subgraph->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", dims.size(),
106 dims.data(), quant, false);
107 subgraph->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", dims.size(),
108 dims.data(), quant, false);
109 subgraph->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", dims.size(),
110 dims.data(), quant, false);
111 subgraph->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", dims.size(),
112 dims.data(), quant, false);
113 TfLiteRegistration reg = AddOpRegistration();
114 int node_index_ignored;
115 subgraph->AddNodeWithParameters({0, 0}, {2}, {}, nullptr, 0, nullptr, ®,
116 &node_index_ignored);
117 subgraph->AddNodeWithParameters({1, 1}, {3}, {}, nullptr, 0, nullptr, ®,
118 &node_index_ignored);
119 subgraph->AddNodeWithParameters({2, 1}, {4}, {}, nullptr, 0, nullptr, ®,
120 &node_index_ignored);
121 }
122
TearDown()123 void TestDelegate::TearDown() {
124 // Interpreter relies on delegate to free the resources properly. Thus
125 // the life cycle of delegate must be longer than interpreter.
126 interpreter_.reset();
127 delegate_.reset();
128 }
129
SimpleDelegate(const std::vector<int> & nodes,int64_t delegate_flags,bool fail_node_prepare,int min_ops_per_subset,bool fail_node_invoke,bool automatic_shape_propagation,bool custom_op)130 TestDelegate::SimpleDelegate::SimpleDelegate(
131 const std::vector<int>& nodes, int64_t delegate_flags,
132 bool fail_node_prepare, int min_ops_per_subset, bool fail_node_invoke,
133 bool automatic_shape_propagation, bool custom_op)
134 : nodes_(nodes),
135 fail_delegate_node_prepare_(fail_node_prepare),
136 min_ops_per_subset_(min_ops_per_subset),
137 fail_delegate_node_invoke_(fail_node_invoke),
138 automatic_shape_propagation_(automatic_shape_propagation),
139 custom_op_(custom_op) {
140 delegate_.Prepare = [](TfLiteContext* context,
141 TfLiteDelegate* delegate) -> TfLiteStatus {
142 auto* simple = static_cast<SimpleDelegate*>(delegate->data_);
143 TfLiteIntArray* nodes_to_separate =
144 TfLiteIntArrayCreate(simple->nodes_.size());
145 // Mark nodes that we want in TfLiteIntArray* structure.
146 int index = 0;
147 for (auto node_index : simple->nodes_) {
148 nodes_to_separate->data[index++] = node_index;
149 // make sure node is added
150 TfLiteNode* node;
151 TfLiteRegistration* reg;
152 context->GetNodeAndRegistration(context, node_index, &node, ®);
153 if (simple->custom_op_) {
154 TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
155 TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
156 } else {
157 TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_ADD);
158 }
159 }
160 // Check that all nodes are available
161 TfLiteIntArray* execution_plan;
162 TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
163 for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
164 int node_index = execution_plan->data[exec_index];
165 TfLiteNode* node;
166 TfLiteRegistration* reg;
167 context->GetNodeAndRegistration(context, node_index, &node, ®);
168 if (exec_index == node_index) {
169 // Check op details only if it wasn't delegated already.
170 if (simple->custom_op_) {
171 TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
172 TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
173 } else {
174 TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_ADD);
175 }
176 }
177 }
178
179 // Get preview of delegate partitioning from the context.
180 TfLiteDelegateParams* params_array;
181 int num_partitions;
182 TFLITE_CHECK_EQ(
183 context->PreviewDelegatePartitioning(context, nodes_to_separate,
184 ¶ms_array, &num_partitions),
185 kTfLiteOk);
186
187 if (simple->min_ops_per_subset() > 0) {
188 // Build a new vector of ops from subsets with at least the minimum
189 // size.
190 std::vector<int> allowed_ops;
191 for (int idx = 0; idx < num_partitions; ++idx) {
192 const auto* nodes_in_subset = params_array[idx].nodes_to_replace;
193 if (nodes_in_subset->size < simple->min_ops_per_subset()) continue;
194 allowed_ops.insert(allowed_ops.end(), nodes_in_subset->data,
195 nodes_in_subset->data + nodes_in_subset->size);
196 }
197
198 // Free existing nodes_to_separate & initialize a new array with
199 // allowed_ops.
200 TfLiteIntArrayFree(nodes_to_separate);
201 nodes_to_separate = TfLiteIntArrayCreate(allowed_ops.size());
202 memcpy(nodes_to_separate->data, allowed_ops.data(),
203 sizeof(int) * nodes_to_separate->size);
204 }
205
206 // Another call to PreviewDelegatePartitioning should be okay, since
207 // partitioning memory is managed by context.
208 TFLITE_CHECK_EQ(
209 context->PreviewDelegatePartitioning(context, nodes_to_separate,
210 ¶ms_array, &num_partitions),
211 kTfLiteOk);
212
213 context->ReplaceNodeSubsetsWithDelegateKernels(
214 context, simple->FakeFusedRegistration(), nodes_to_separate, delegate);
215 TfLiteIntArrayFree(nodes_to_separate);
216 return kTfLiteOk;
217 };
218 delegate_.CopyToBufferHandle = [](TfLiteContext* context,
219 TfLiteDelegate* delegate,
220 TfLiteBufferHandle buffer_handle,
221 TfLiteTensor* tensor) -> TfLiteStatus {
222 // TODO(b/156586986): Implement tests to test buffer copying logic.
223 return kTfLiteOk;
224 };
225 delegate_.CopyFromBufferHandle = [](TfLiteContext* context,
226 TfLiteDelegate* delegate,
227 TfLiteBufferHandle buffer_handle,
228 TfLiteTensor* output) -> TfLiteStatus {
229 TFLITE_CHECK_GE(buffer_handle, -1);
230 TFLITE_CHECK_EQ(output->buffer_handle, buffer_handle);
231 const float floats[] = {6., 6., 6.};
232 int num = output->dims->data[0];
233 for (int i = 0; i < num; i++) {
234 output->data.f[i] = floats[i];
235 }
236 return kTfLiteOk;
237 };
238
239 delegate_.FreeBufferHandle =
240 [](TfLiteContext* context, TfLiteDelegate* delegate,
241 TfLiteBufferHandle* handle) { *handle = kTfLiteNullBufferHandle; };
242 // Store type-punned data SimpleDelegate structure.
243 delegate_.data_ = static_cast<void*>(this);
244 delegate_.flags = delegate_flags;
245 }
246
FakeFusedRegistration()247 TfLiteRegistration TestDelegate::SimpleDelegate::FakeFusedRegistration() {
248 TfLiteRegistration reg = {nullptr};
249 reg.custom_name = "fake_fused_op";
250
251 // Different flavors of the delegate kernel's Invoke(), dependent on
252 // testing parameters.
253 if (fail_delegate_node_invoke_) {
254 reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
255 return kTfLiteError;
256 };
257 } else {
258 reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
259 // Copy input data to output data.
260 const TfLiteTensor* a0;
261 const TfLiteTensor* a1;
262 if (node->inputs->size == 2) {
263 a0 = GetInput(context, node, 0);
264 a1 = GetInput(context, node, 1);
265 } else {
266 a0 = GetInput(context, node, 0);
267 a1 = a0;
268 }
269 TfLiteTensor* out;
270 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out));
271 int num = 1;
272 for (int i = 0; i < a0->dims->size; ++i) {
273 num *= a0->dims->data[i];
274 }
275 for (int i = 0; i < num; i++) {
276 out->data.f[i] = a0->data.f[i] + a1->data.f[i];
277 }
278 if (out->buffer_handle != kTfLiteNullBufferHandle) {
279 // Make the data stale so that CopyFromBufferHandle can be invoked
280 out->data_is_stale = true;
281 }
282 return kTfLiteOk;
283 };
284 }
285
286 // Different flavors of the delegate kernel's Prepare(), dependent on
287 // testing parameters.
288 if (automatic_shape_propagation_) {
289 reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
290 // Shapes should already by propagated by the runtime, just need to
291 // check.
292 const TfLiteTensor* input1;
293 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input1));
294 TfLiteTensor* output;
295 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
296 const int input_dims_size = input1->dims->size;
297 TF_LITE_ENSURE(context, output->dims->size == input_dims_size);
298 for (int i = 0; i < input_dims_size; ++i) {
299 TF_LITE_ENSURE(context, output->dims->data[i] == input1->dims->data[i]);
300 }
301 return kTfLiteOk;
302 };
303 } else if (fail_delegate_node_prepare_) {
304 reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
305 return kTfLiteError;
306 };
307 } else {
308 reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
309 // Set output size to input size
310 const TfLiteTensor* input1;
311 const TfLiteTensor* input2;
312 if (node->inputs->size == 2) {
313 input1 = GetInput(context, node, 0);
314 input2 = GetInput(context, node, 1);
315 } else {
316 input1 = GetInput(context, node, 0);
317 input2 = input1;
318 }
319 TfLiteTensor* output;
320 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
321
322 TF_LITE_ENSURE_STATUS(context->ResizeTensor(
323 context, output, TfLiteIntArrayCopy(input1->dims)));
324 return kTfLiteOk;
325 };
326 }
327
328 return reg;
329 }
330
SetUp()331 void TestFP16Delegation::SetUp() {
332 interpreter_.reset(new Interpreter);
333 interpreter_->AddTensors(13);
334 interpreter_->SetInputs({0});
335 interpreter_->SetOutputs({12});
336
337 float16_const_ = Eigen::half_impl::float_to_half_rtne(2.f);
338
339 // TENSORS.
340 TfLiteQuantizationParams quant;
341 // Input.
342 interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {1}, quant);
343 // fp16 constant, dequantize output, Add0 output.
344 interpreter_->SetTensorParametersReadOnly(
345 1, kTfLiteFloat16, "", {1}, quant,
346 reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
347 interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {1}, quant);
348 interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {1}, quant);
349 // fp16 constant, dequantize output, Add1 output.
350 interpreter_->SetTensorParametersReadOnly(
351 4, kTfLiteFloat16, "", {1}, quant,
352 reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
353 interpreter_->SetTensorParametersReadWrite(5, kTfLiteFloat32, "", {1}, quant);
354 interpreter_->SetTensorParametersReadWrite(6, kTfLiteFloat32, "", {1}, quant);
355 // fp16 constant, dequantize output, Mul0 output.
356 interpreter_->SetTensorParametersReadOnly(
357 7, kTfLiteFloat16, "", {1}, quant,
358 reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
359 interpreter_->SetTensorParametersReadWrite(8, kTfLiteFloat32, "", {1}, quant);
360 interpreter_->SetTensorParametersReadWrite(9, kTfLiteFloat32, "", {1}, quant);
361 // fp16 constant, dequantize output, Add2 output.
362 interpreter_->SetTensorParametersReadOnly(
363 10, kTfLiteFloat16, "", {1}, quant,
364 reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
365 interpreter_->SetTensorParametersReadWrite(11, kTfLiteFloat32, "", {1},
366 quant);
367 interpreter_->SetTensorParametersReadWrite(12, kTfLiteFloat32, "", {1},
368 quant);
369
370 // NODES.
371 auto* add_reg = ops::builtin::Register_ADD();
372 auto* mul_reg = ops::builtin::Register_MUL();
373 auto* deq_reg = ops::builtin::Register_DEQUANTIZE();
374 add_reg->builtin_code = kTfLiteBuiltinAdd;
375 deq_reg->builtin_code = kTfLiteBuiltinDequantize;
376 mul_reg->builtin_code = kTfLiteBuiltinMul;
377 TfLiteAddParams* builtin_data0 =
378 reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
379 TfLiteAddParams* builtin_data1 =
380 reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
381 TfLiteMulParams* builtin_data2 =
382 reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams)));
383 TfLiteAddParams* builtin_data3 =
384 reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
385 builtin_data0->activation = kTfLiteActNone;
386 builtin_data1->activation = kTfLiteActNone;
387 builtin_data2->activation = kTfLiteActNone;
388 builtin_data3->activation = kTfLiteActNone;
389 interpreter_->AddNodeWithParameters({1}, {2}, nullptr, 0, nullptr, deq_reg);
390 interpreter_->AddNodeWithParameters({0, 2}, {3}, nullptr, 0, builtin_data0,
391 add_reg);
392 interpreter_->AddNodeWithParameters({4}, {5}, nullptr, 0, nullptr, deq_reg);
393 interpreter_->AddNodeWithParameters({3, 5}, {6}, nullptr, 0, builtin_data1,
394 add_reg);
395 interpreter_->AddNodeWithParameters({7}, {8}, nullptr, 0, nullptr, deq_reg);
396 interpreter_->AddNodeWithParameters({6, 8}, {9}, nullptr, 0, builtin_data2,
397 mul_reg);
398 interpreter_->AddNodeWithParameters({10}, {11}, nullptr, 0, nullptr, deq_reg);
399 interpreter_->AddNodeWithParameters({9, 11}, {12}, nullptr, 0, builtin_data3,
400 add_reg);
401 }
402
VerifyInvoke()403 void TestFP16Delegation::VerifyInvoke() {
404 std::vector<float> input = {3.0f};
405 std::vector<float> expected_output = {16.0f};
406
407 const int input_tensor_idx = interpreter_->inputs()[0];
408 const int output_tensor_idx = interpreter_->outputs()[0];
409
410 memcpy(interpreter_->typed_tensor<float>(input_tensor_idx), input.data(),
411 sizeof(float));
412 ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
413 TfLiteTensor* output_tensor = interpreter_->tensor(output_tensor_idx);
414 for (int i = 0; i < 1; ++i) {
415 EXPECT_EQ(output_tensor->data.f[i], expected_output[i]) << i;
416 }
417 }
418
FP16Delegate(int num_delegated_subsets,bool fail_node_prepare,bool fail_node_invoke)419 TestFP16Delegation::FP16Delegate::FP16Delegate(int num_delegated_subsets,
420 bool fail_node_prepare,
421 bool fail_node_invoke)
422 : num_delegated_subsets_(num_delegated_subsets),
423 fail_delegate_node_prepare_(fail_node_prepare),
424 fail_delegate_node_invoke_(fail_node_invoke) {
425 delegate_.Prepare = [](TfLiteContext* context,
426 TfLiteDelegate* delegate) -> TfLiteStatus {
427 auto* fp16_delegate = static_cast<FP16Delegate*>(delegate->data_);
428 // FP16 graph partitioning.
429 delegates::IsNodeSupportedFn node_supported_fn =
430 [=](TfLiteContext* context, TfLiteNode* node,
431 TfLiteRegistration* registration,
432 std::string* unsupported_details) -> bool {
433 return registration->builtin_code == kTfLiteBuiltinAdd;
434 };
435 delegates::FP16GraphPartitionHelper partition_helper(context,
436 node_supported_fn);
437 TfLiteIntArray* nodes_to_separate = nullptr;
438 if (partition_helper.Partition(nullptr) != kTfLiteOk) {
439 nodes_to_separate = TfLiteIntArrayCreate(0);
440 } else {
441 std::vector<int> ops_to_replace =
442 partition_helper.GetNodesOfFirstNLargestPartitions(
443 fp16_delegate->num_delegated_subsets());
444 nodes_to_separate = ConvertVectorToTfLiteIntArray(ops_to_replace);
445 }
446
447 context->ReplaceNodeSubsetsWithDelegateKernels(
448 context, fp16_delegate->FakeFusedRegistration(), nodes_to_separate,
449 delegate);
450 TfLiteIntArrayFree(nodes_to_separate);
451 return kTfLiteOk;
452 };
453 delegate_.CopyFromBufferHandle =
454 [](TfLiteContext* context, TfLiteDelegate* delegate,
455 TfLiteBufferHandle buffer_handle,
456 TfLiteTensor* output) -> TfLiteStatus { return kTfLiteOk; };
457 delegate_.FreeBufferHandle = nullptr;
458 delegate_.CopyToBufferHandle = nullptr;
459 // Store type-punned data SimpleDelegate structure.
460 delegate_.data_ = static_cast<void*>(this);
461 delegate_.flags = kTfLiteDelegateFlagsNone;
462 }
463
FakeFusedRegistration()464 TfLiteRegistration TestFP16Delegation::FP16Delegate::FakeFusedRegistration() {
465 TfLiteRegistration reg = {nullptr};
466 reg.custom_name = "fake_fp16_add_op";
467
468 // Different flavors of the delegate kernel's Invoke(), dependent on
469 // testing parameters.
470 if (fail_delegate_node_invoke_) {
471 reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
472 return kTfLiteError;
473 };
474 } else {
475 reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
476 float output = 0;
477 for (int i = 0; i < node->inputs->size; ++i) {
478 const TfLiteTensor* input_tensor = GetInput(context, node, i);
479 if (input_tensor->type == kTfLiteFloat32) {
480 output += input_tensor->data.f[0];
481 } else {
482 // All constants are 2.
483 output += 2;
484 }
485 }
486 TfLiteTensor* out = GetOutput(context, node, 0);
487 out->data.f[0] = output;
488 return kTfLiteOk;
489 };
490 }
491
492 // Different flavors of the delegate kernel's Prepare(), dependent on
493 // testing parameters.
494 if (fail_delegate_node_prepare_) {
495 reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
496 return kTfLiteError;
497 };
498 } else {
499 reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
500 // Set output size to input size
501 const TfLiteTensor* input = GetInput(context, node, 0);
502 TfLiteTensor* output = GetOutput(context, node, 0);
503 TF_LITE_ENSURE_STATUS(context->ResizeTensor(
504 context, output, TfLiteIntArrayCopy(input->dims)));
505 return kTfLiteOk;
506 };
507 }
508
509 return reg;
510 }
511
512 } // namespace test_utils
513 } // namespace delegates
514 } // namespace tflite
515