/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/reduce.h"

#include <stddef.h>

#include <cstdint>
#include <limits>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace reduce {

// This file has the reference implementations of the reduce_* operators.
enum KernelType {
  kReference,
  kGenericOptimized,
};

struct OpData {
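  // Fixed-point multiplier and shift produced by QuantizeMultiplier() in
  // PrepareMeanOrSum(); together they encode the input_scale / output_scale
  // ratio used to rescale the quantized Mean.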
  int32_t multiplier;
  int shift;
  // The base index of the three temporary tensors (input iteration index,
  // resolved axis, temp accumulator) allocated in Init().
  int scratch_tensor_index;
};

struct OpContext {
  OpContext(TfLiteContext* context, TfLiteNode* node) {
    params = reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
    input = GetInput(context, node, 0);
    axis = GetInput(context, node, 1);
    output = GetOutput(context, node, 0);
  }
  TfLiteReducerParams* params;
  const TfLiteTensor* input;
  const TfLiteTensor* axis;
  TfLiteTensor* output;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  // Creates three temp tensors to store the input iteration index, the
  // resolved axis, and the temp accumulator for internal implementation only.
  auto* op_data = new OpData();
  context->AddTensors(context, 3, &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

// Resizes the temp tensor that stores resolved axis.
TfLiteStatus ResizeTempAxis(TfLiteContext* context, OpContext* op_context,
                            TfLiteTensor* resolved_axis) {
  TfLiteIntArray* axis_size = TfLiteIntArrayCreate(1);
  axis_size->data[0] = static_cast<int>(NumElements(op_context->axis));
  return context->ResizeTensor(context, resolved_axis, axis_size);
}

// Resizes the temp tensor that stores temp sum of reduced elements.
TfLiteStatus ResizeTempSum(TfLiteContext* context, OpContext* op_context,
                           TfLiteTensor* temp_sum) {
  TfLiteIntArray* size = TfLiteIntArrayCreate(1);
  size->data[0] = static_cast<int>(NumElements(op_context->output));
  return context->ResizeTensor(context, temp_sum, size);
}

// Resizes output array based on the input size and resolved axis.
TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) {
  size_t num_axis = NumElements(op_context->axis);
  const TfLiteIntArray* input_dims = op_context->input->dims;
  int input_num_dims = NumDimensions(op_context->input);
  if (input_num_dims == 0) {
    return context->ResizeTensor(context, op_context->output,
                                 TfLiteIntArrayCreate(0));
  }
  const int* axis = GetTensorData<int>(op_context->axis);
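  // Axis values may be negative (counting back from the last dimension), so
  // the checks below treat `axis[i] + input_num_dims` as the wrapped-around
  // equivalent of `axis[i]`.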
  if (op_context->params->keep_dims) {
    TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_num_dims);
    for (int idx = 0; idx < input_num_dims; ++idx) {
      bool is_axis = false;
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) {
          is_axis = true;
          break;
        }
      }
      if (is_axis) {
        output_dims->data[idx] = 1;
      } else {
        output_dims->data[idx] = input_dims->data[idx];
      }
    }
    return context->ResizeTensor(context, op_context->output, output_dims);
  } else {
    // Calculates size of reducing axis.
    int num_reduce_axis = num_axis;
    for (int i = 0; i < num_axis; ++i) {
      int current = axis[i];
      if (current < 0) {
        current += input_num_dims;
      }
      TF_LITE_ENSURE(context, current >= 0 && current < input_num_dims);
      for (int j = 0; j < i; ++j) {
        int previous = axis[j];
        if (previous < 0) {
          previous += input_num_dims;
        }
        if (current == previous) {
          --num_reduce_axis;
          break;
        }
      }
    }
    // Determines output dimensions.
    TfLiteIntArray* output_dims =
        TfLiteIntArrayCreate(input_num_dims - num_reduce_axis);
    int num_skip_axis = 0;
    for (int idx = 0; idx < input_num_dims; ++idx) {
      bool is_axis = false;
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) {
          ++num_skip_axis;
          is_axis = true;
          break;
        }
      }
      if (!is_axis) {
        output_dims->data[idx - num_skip_axis] = input_dims->data[idx];
      }
    }
    return context->ResizeTensor(context, op_context->output, output_dims);
  }
}

// Initializes temp tensors to store index and resolved axis.
TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
                                   OpContext* op_context) {
  // Creates a temp index to iterate through input data.
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(3);
  node->temporaries->data[0] = op_data->scratch_tensor_index;
  TfLiteTensor* scratch_tensor;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/0, &scratch_tensor));
  scratch_tensor->type = kTfLiteInt32;
  scratch_tensor->allocation_type = kTfLiteArenaRw;
  TfLiteIntArray* index_size = TfLiteIntArrayCreate(1);
  index_size->data[0] = NumDimensions(op_context->input);
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, scratch_tensor, index_size));

  // Creates a temp tensor to store resolved axis given input data.
  node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
  TfLiteTensor* resolved_axis;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
  resolved_axis->type = kTfLiteInt32;
  // Creates a temp tensor to store temp sums when calculating mean.
  node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
  TfLiteTensor* temp_sum;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
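  // The accumulator type is wider than the input type so that partial sums do
  // not overflow: 32-bit integers accumulate in int64, and the 8/16-bit
  // quantized types accumulate in int32.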
  switch (op_context->input->type) {
    case kTfLiteFloat32:
      temp_sum->type = kTfLiteFloat32;
      break;
    case kTfLiteInt32:
      temp_sum->type = kTfLiteInt64;
      break;
    case kTfLiteInt64:
      temp_sum->type = kTfLiteInt64;
      break;
    case kTfLiteUInt8:
    case kTfLiteInt8:
    case kTfLiteInt16:
      temp_sum->type = kTfLiteInt32;
      break;
    case kTfLiteBool:
      temp_sum->type = kTfLiteBool;
      break;
    default:
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  OpContext op_context(context, node);
  TF_LITE_ENSURE_TYPES_EQ(context, op_context.axis->type, kTfLiteInt32);
  TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context));

  if (op_context.input->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point, 0);
  }

  TfLiteTensor* resolved_axis;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
  // Leaves work to Eval if axis is not constant; else resizes output.
  if (!IsConstantTensor(op_context.axis)) {
    SetTensorToDynamic(op_context.output);
    SetTensorToDynamic(resolved_axis);
    return kTfLiteOk;
  }
  resolved_axis->allocation_type = kTfLiteArenaRw;
  TF_LITE_ENSURE_OK(context,
                    ResizeTempAxis(context, &op_context, resolved_axis));
  TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
  return kTfLiteOk;
}

TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteBool);
  return PrepareSimple(context, node);
}

TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  // reduce_mean requires a buffer to store intermediate sum result.
  OpContext op_context(context, node);
  if (op_context.input->type == kTfLiteInt8 ||
      op_context.input->type == kTfLiteUInt8 ||
      op_context.input->type == kTfLiteInt16) {
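    // The input/output scale ratio is folded into a 32-bit fixed-point
    // multiplier plus a power-of-two shift so that Eval can rescale the
    // result with integer arithmetic only.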
    const double real_multiplier =
        static_cast<double>(op_context.input->params.scale) /
        static_cast<double>(op_context.output->params.scale);
    int exponent;
    QuantizeMultiplier(real_multiplier, &data->multiplier, &exponent);
    data->shift = exponent;
  }

  if (op_context.input->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point, 0);
  }

  TfLiteTensor* temp_sum;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
  if (!IsConstantTensor(op_context.axis)) {
    SetTensorToDynamic(temp_sum);
    return kTfLiteOk;
  }
  temp_sum->allocation_type = kTfLiteArenaRw;
  return ResizeTempSum(context, &op_context, temp_sum);
}

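// Copies the int32 axis values into the fixed-size (4-entry) axis array of
// tflite::MeanParams, padding any unused trailing entries with 1 so the
// specialized Mean kernels never read uninitialized values.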
void ResolveAxis(const int* axis_data, int axis_count,
                 tflite::MeanParams* op_params) {
  int i = 0;
  for (; i < axis_count; ++i) {
    op_params->axis[i] = static_cast<int16>(axis_data[i]);
  }
  for (; i < 4; ++i) {
    op_params->axis[i] = 1;
  }
}

template <typename integer_type>
TfLiteStatus EvalMeanReferenceOps(TfLiteContext* context,
                                  const OpContext& op_context, int num_axis,
                                  OpData* data, TfLiteTensor* temp_index,
                                  TfLiteTensor* resolved_axis,
                                  TfLiteTensor* temp_sum) {
  tflite::MeanParams op_params;
  op_params.axis_count = num_axis;
  ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
  const TfLiteTensor* input = op_context.input;
  // Return early when input shape has zero dim.
  for (int i = 0; i < input->dims->size; ++i) {
    if (input->dims->data[i] == 0) return kTfLiteOk;
  }

  // TODO(b/139102329): Handle all the cases in the combined reference
  // method.
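  // The specialized kernels below cover the common NHWC case: a keep_dims
  // Mean over the height and width axes (1 and 2, in either order) of a
  // 4-D tensor.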
  if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
      op_params.axis_count == 2 &&
      ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
       (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
    if (std::is_same<integer_type, uint8_t>::value) {
      reference_ops::Mean(op_params, GetTensorShape(op_context.input),
                          GetTensorData<uint8_t>(op_context.input),
                          op_context.input->params.zero_point,
                          op_context.input->params.scale,
                          GetTensorShape(op_context.output),
                          GetTensorData<uint8_t>(op_context.output),
                          op_context.output->params.zero_point,
                          op_context.output->params.scale);
    } else {
      reference_integer_ops::Mean(
          op_params, data->multiplier, data->shift, GetTensorShape(input),
          GetTensorData<integer_type>(input),
          op_context.input->params.zero_point,
          GetTensorShape(op_context.output),
          GetTensorData<integer_type>(op_context.output),
          op_context.output->params.zero_point);
    }
  } else if (input->params.zero_point == op_context.output->params.zero_point &&
             input->params.scale == op_context.output->params.scale) {
    TF_LITE_ENSURE(
        context,
        reference_ops::Mean(
            GetTensorData<integer_type>(input), input->dims->data,
            input->dims->size, GetTensorData<integer_type>(op_context.output),
            op_context.output->dims->data, op_context.output->dims->size,
            GetTensorData<int>(op_context.axis), num_axis,
            op_context.params->keep_dims, GetTensorData<int>(temp_index),
            GetTensorData<int>(resolved_axis), GetTensorData<int>(temp_sum)));
  } else {
    TF_LITE_ENSURE(
        context,
        reference_ops::QuantizedMeanOrSum<>(
            GetTensorData<integer_type>(input), input->params.zero_point,
            input->params.scale, input->dims->data, input->dims->size,
            GetTensorData<integer_type>(op_context.output),
            op_context.output->params.zero_point,
            op_context.output->params.scale, op_context.output->dims->data,
            op_context.output->dims->size, GetTensorData<int>(op_context.axis),
            num_axis, op_context.params->keep_dims,
            GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
            GetTensorData<int>(temp_sum),
            /*compute_sum=*/false));
  }
  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  int num_axis = static_cast<int>(NumElements(op_context.axis));
  TfLiteTensor* temp_index;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/0, &temp_index));
  TfLiteTensor* resolved_axis;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
  TfLiteTensor* temp_sum;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
  // Resize the output tensor if the output tensor is dynamic.
  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context,
                      ResizeTempAxis(context, &op_context, resolved_axis));
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
    TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
  }

  // Return early when input shape has zero dim.
  const TfLiteTensor* input = op_context.input;
  for (int i = 0; i < input->dims->size; ++i) {
    if (input->dims->data[i] == 0) return kTfLiteOk;
  }

  if (kernel_type == kGenericOptimized) {
    // Use optimized ops if available.
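    // Only the 4-D keep_dims Mean over axes {1, 2} has an optimized int8 /
    // uint8 kernel; every other shape or type falls through to the reference
    // implementations below.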
    switch (input->type) {
      case kTfLiteInt8: {
        tflite::MeanParams op_params;
        op_params.axis_count = num_axis;
        ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
        if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
            op_params.axis_count == 2 &&
            ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
             (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
          optimized_integer_ops::Mean(
              op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
              input->params.zero_point, input->params.scale,
              GetTensorShape(op_context.output),
              GetTensorData<int8_t>(op_context.output),
              op_context.output->params.zero_point,
              op_context.output->params.scale,
              CpuBackendContext::GetFromContext(context));
          return kTfLiteOk;
        }
      } break;
      case kTfLiteUInt8: {
        tflite::MeanParams op_params;
        op_params.axis_count = num_axis;
        ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
        if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
            op_params.axis_count == 2 &&
            ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
             (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
          optimized_ops::Mean(op_params, GetTensorShape(input),
                              GetTensorData<uint8_t>(input),
                              input->params.zero_point, input->params.scale,
                              GetTensorShape(op_context.output),
                              GetTensorData<uint8_t>(op_context.output),
                              op_context.output->params.zero_point,
                              op_context.output->params.scale,
                              CpuBackendContext::GetFromContext(context));
          return kTfLiteOk;
        }
      } break;
      default:
        break;
    }
  }

  // From here on, the reference implementations are used.
  // TODO(b/139102329): Clean up the function signatures to merge the
  // variations and handle the specialized cases in the combined reference
  // implementations for each op.
  switch (op_context.input->type) {
    case kTfLiteFloat32: {
      tflite::MeanParams op_params;
      op_params.axis_count = num_axis;
      ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
      const TfLiteTensor* input = op_context.input;
      // TODO(b/139102329): Handle the below special case in the combined
      // reference method.
      // Defer to specialized implementation for 4D Mean across axes 1 & 2.
      if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
          op_params.axis_count == 2 &&
          ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
           (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
        reference_ops::Mean(op_params, GetTensorShape(input),
                            GetTensorData<float>(input),
                            GetTensorShape(op_context.output),
                            GetTensorData<float>(op_context.output));
      } else {
        TF_LITE_ENSURE(
            context,
            optimized_ops::MeanGeneral(
                GetTensorData<float>(op_context.input),
                op_context.input->dims->data, op_context.input->dims->size,
                GetTensorData<float>(op_context.output),
                op_context.output->dims->data, op_context.output->dims->size,
                GetTensorData<int>(op_context.axis), num_axis,
                op_context.params->keep_dims, GetTensorData<int>(temp_index),
                GetTensorData<int>(resolved_axis),
                GetTensorData<float>(temp_sum)));
      }
    } break;
    case kTfLiteInt32:
      TF_LITE_ENSURE(
          context,
          reference_ops::Mean(
              GetTensorData<int>(op_context.input),
              op_context.input->dims->data, op_context.input->dims->size,
              GetTensorData<int>(op_context.output),
              op_context.output->dims->data, op_context.output->dims->size,
              GetTensorData<int>(op_context.axis), num_axis,
              op_context.params->keep_dims, GetTensorData<int>(temp_index),
              GetTensorData<int>(resolved_axis),
              GetTensorData<int64_t>(temp_sum)));
      break;
    case kTfLiteInt64:
      TF_LITE_ENSURE(
          context,
          reference_ops::Mean(
              GetTensorData<int64_t>(op_context.input),
              op_context.input->dims->data, op_context.input->dims->size,
              GetTensorData<int64_t>(op_context.output),
              op_context.output->dims->data, op_context.output->dims->size,
              GetTensorData<int>(op_context.axis), num_axis,
              op_context.params->keep_dims, GetTensorData<int>(temp_index),
              GetTensorData<int>(resolved_axis),
              GetTensorData<int64_t>(temp_sum)));
      break;
    case kTfLiteInt8: {
      TF_LITE_ENSURE_OK(context, EvalMeanReferenceOps<int8_t>(
                                     context, op_context, num_axis, data,
                                     temp_index, resolved_axis, temp_sum));
    } break;
    case kTfLiteInt16: {
      TF_LITE_ENSURE_OK(context, EvalMeanReferenceOps<int16_t>(
                                     context, op_context, num_axis, data,
                                     temp_index, resolved_axis, temp_sum));
    } break;
    case kTfLiteUInt8: {
      TF_LITE_ENSURE_OK(context, EvalMeanReferenceOps<uint8_t>(
                                     context, op_context, num_axis, data,
                                     temp_index, resolved_axis, temp_sum));
    } break;
    default:
      return kTfLiteError;
  }
  return kTfLiteOk;
}

// The underlying logic for Reduce Sum/Prod/Max/Min/Any.
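// `init_value` seeds the accumulation and `reducer` folds one input element
// into the running value; e.g. Sum passes 0 with `in + current`, Prod passes
// 1 with `in * current` (see EvalType below).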
template <typename T>
TfLiteStatus EvalLogic(TfLiteContext* context, TfLiteNode* node,
                       OpContext* op_context, T init_value,
                       T reducer(const T current, const T in)) {
  int64_t num_axis = NumElements(op_context->axis);
  TfLiteTensor* temp_index;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/0, &temp_index));
  TfLiteTensor* resolved_axis;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
  // Resize the output tensor if the output tensor is dynamic.
  if (IsDynamicTensor(op_context->output)) {
    TF_LITE_ENSURE_OK(context,
                      ResizeTempAxis(context, op_context, resolved_axis));
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, op_context));
  }

  const TfLiteTensor* input = op_context->input;
  // Return early when input shape has zero dim.
  for (int i = 0; i < input->dims->size; ++i) {
    if (input->dims->data[i] == 0) return kTfLiteOk;
  }

  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
      input->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, input->params.scale,
                      op_context->output->params.scale);
    TF_LITE_ENSURE_EQ(context, input->params.zero_point,
                      op_context->output->params.zero_point);
  }
  TF_LITE_ENSURE(
      context,
      reference_ops::ReduceGeneric<T>(
          GetTensorData<T>(input), input->dims->data, input->dims->size,
          GetTensorData<T>(op_context->output), op_context->output->dims->data,
          op_context->output->dims->size, GetTensorData<int>(op_context->axis),
          num_axis, op_context->params->keep_dims,
          GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
          init_value, reducer));
  return kTfLiteOk;
}

enum ReduceType {
  kSum,
  kProd,
  kMax,
  kMin,
  kAny,
};

// Eval for determined input type and reduce type.
template <typename T>
TfLiteStatus EvalType(TfLiteContext* context, TfLiteNode* node,
                      OpContext* op_context, ReduceType reduce_type) {
  switch (reduce_type) {
    case kSum:
      return EvalLogic<T>(
          context, node, op_context, static_cast<T>(0),
          [](const T current, const T in) -> T { return in + current; });
      break;
    case kProd:
      return EvalLogic<T>(
          context, node, op_context, static_cast<T>(1),
          [](const T current, const T in) -> T { return in * current; });
      break;
    case kMax:
      return EvalLogic<T>(context, node, op_context,
                          std::numeric_limits<T>::lowest(),
                          [](const T current, const T in) -> T {
                            return (in > current) ? in : current;
                          });
      break;
    case kMin:
      return EvalLogic<T>(context, node, op_context,
                          std::numeric_limits<T>::max(),
                          [](const T current, const T in) -> T {
                            return (in < current) ? in : current;
                          });
      break;
    default:
      return kTfLiteError;
  }
}

// Template specialization for bool type.
template <>
TfLiteStatus EvalType<bool>(TfLiteContext* context, TfLiteNode* node,
                            OpContext* op_context, ReduceType reduce_type) {
  switch (reduce_type) {
    case kAny:
      return EvalLogic<bool>(context, node, op_context, false,
                             [](const bool current, const bool in) -> bool {
                               return in || current;
                             });
      break;
    default:
      return kTfLiteError;
  }
}

// The entry point that handles input types and then calls template functions
// to handle ReduceType.
template <KernelType kernel_type, ReduceType reduce_type>
TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) {
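  // Only the reference path is wired up for these ops; any other kernel type
  // is treated as a no-op.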
  if (kernel_type != kReference) {
    return kTfLiteOk;
  }
  OpContext op_context(context, node);
  switch (op_context.input->type) {
    case kTfLiteFloat32:
      return EvalType<float>(context, node, &op_context, reduce_type);
      break;
    case kTfLiteInt32:
      return EvalType<int>(context, node, &op_context, reduce_type);
      break;
    case kTfLiteInt64:
      return EvalType<int64_t>(context, node, &op_context, reduce_type);
      break;
    case kTfLiteUInt8:
      return EvalType<uint8_t>(context, node, &op_context, reduce_type);
      break;
    case kTfLiteInt8:
      return EvalType<int8_t>(context, node, &op_context, reduce_type);
      break;
    case kTfLiteInt16:
      return EvalType<int16_t>(context, node, &op_context, reduce_type);
      break;
    case kTfLiteBool:
      return EvalType<bool>(context, node, &op_context, reduce_type);
      break;
    default:
      return kTfLiteError;
  }
}

TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);
  ruy::profiler::ScopeLabel label("Sum");
  const auto& input = op_context.input;
  const auto& output = op_context.output;
  const bool same_scale =
      (input->params.scale == output->params.scale &&
       input->params.zero_point == output->params.zero_point);
  const bool eight_bit_quantized =
      input->type == kTfLiteUInt8 || input->type == kTfLiteInt8;
  const bool need_rescale = (eight_bit_quantized && !same_scale);
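  // When the 8-bit input and output quantization parameters differ, the sum
  // is accumulated in int32 and then requantized by QuantizedMeanOrSum;
  // otherwise the generic same-scale reduction below suffices.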
  if (need_rescale) {
    // Rescaling 8bit reduce sum.
    int num_axis = static_cast<int>(NumElements(op_context.axis));
    TfLiteTensor* temp_index;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/0, &temp_index));
    TfLiteTensor* resolved_axis;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
    TfLiteTensor* temp_sum;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
    // Resize the output tensor if the output tensor is dynamic.
    if (IsDynamicTensor(op_context.output)) {
      TF_LITE_ENSURE_OK(context,
                        ResizeTempAxis(context, &op_context, resolved_axis));
      TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
      TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
    }
    // Return early when input shape has zero dim.
    for (int i = 0; i < input->dims->size; ++i) {
      if (input->dims->data[i] == 0) return kTfLiteOk;
    }

    if (input->type == kTfLiteUInt8) {
      TF_LITE_ENSURE(
          context,
          reference_ops::QuantizedMeanOrSum<>(
              GetTensorData<uint8_t>(op_context.input),
              op_context.input->params.zero_point,
              op_context.input->params.scale, op_context.input->dims->data,
              op_context.input->dims->size,
              GetTensorData<uint8_t>(op_context.output),
              op_context.output->params.zero_point,
              op_context.output->params.scale, op_context.output->dims->data,
              op_context.output->dims->size,
              GetTensorData<int>(op_context.axis), num_axis,
              op_context.params->keep_dims, GetTensorData<int>(temp_index),
              GetTensorData<int>(resolved_axis), GetTensorData<int32>(temp_sum),
              /*compute_sum=*/true));
    }
    if (input->type == kTfLiteInt8) {
      TF_LITE_ENSURE(
          context,
          reference_ops::QuantizedMeanOrSum<>(
              GetTensorData<int8_t>(op_context.input),
              op_context.input->params.zero_point,
              op_context.input->params.scale, op_context.input->dims->data,
              op_context.input->dims->size,
              GetTensorData<int8_t>(op_context.output),
              op_context.output->params.zero_point,
              op_context.output->params.scale, op_context.output->dims->data,
              op_context.output->dims->size,
              GetTensorData<int>(op_context.axis), num_axis,
              op_context.params->keep_dims, GetTensorData<int>(temp_index),
              GetTensorData<int>(resolved_axis), GetTensorData<int32>(temp_sum),
              /*compute_sum=*/true));
    }
  } else {
    return EvalGeneric<kReference, kSum>(context, node);
  }

  return kTfLiteOk;
}
}  // namespace reduce

TfLiteRegistration* Register_MEAN_OPT() {
  static TfLiteRegistration r = {reduce::Init, reduce::Free,
                                 reduce::PrepareMeanOrSum,
                                 reduce::EvalMean<reduce::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_MEAN_REF() {
  static TfLiteRegistration r = {reduce::Init, reduce::Free,
                                 reduce::PrepareMeanOrSum,
                                 reduce::EvalMean<reduce::kReference>};
  return &r;
}

TfLiteRegistration* Register_SUM_REF() {
  static TfLiteRegistration r = {reduce::Init, reduce::Free,
                                 reduce::PrepareMeanOrSum, reduce::EvalSum};
  return &r;
}

TfLiteRegistration* Register_REDUCE_PROD_REF() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareSimple,
      reduce::EvalGeneric<reduce::kReference, reduce::kProd>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_MAX_REF() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareSimple,
      reduce::EvalGeneric<reduce::kReference, reduce::kMax>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_MIN_REF() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareSimple,
      reduce::EvalGeneric<reduce::kReference, reduce::kMin>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_ANY_REF() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareAny,
      reduce::EvalGeneric<reduce::kReference, reduce::kAny>};
  return &r;
}

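// On NEON-capable builds Mean dispatches to the optimized kernel; the other
// reduce ops currently register only their reference implementations.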
TfLiteRegistration* Register_MEAN() {
#ifdef USE_NEON
  return Register_MEAN_OPT();
#else
  return Register_MEAN_REF();
#endif
}

TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); }
TfLiteRegistration* Register_REDUCE_PROD() {
  return Register_REDUCE_PROD_REF();
}
TfLiteRegistration* Register_REDUCE_MAX() { return Register_REDUCE_MAX_REF(); }
TfLiteRegistration* Register_REDUCE_MIN() { return Register_REDUCE_MIN_REF(); }
TfLiteRegistration* Register_REDUCE_ANY() { return Register_REDUCE_ANY_REF(); }

}  // namespace builtin
}  // namespace ops
}  // namespace tflite