/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/reduce.h"

#include <stddef.h>

#include <cstdint>
#include <limits>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace reduce {

// This file contains the reference implementations of the reduce_* operators.
enum KernelType {
  kReference,
  kGenericOptimized,
};

struct OpData {
  int32_t multiplier;
  int shift;
  // The base index of the three temporary tensors allocated in Init():
  // the iteration index, the resolved axis, and the intermediate sum.
  int scratch_tensor_index;
};

struct OpContext {
  OpContext(TfLiteContext* context, TfLiteNode* node) {
    params = reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
    input = GetInput(context, node, 0);
    axis = GetInput(context, node, 1);
    output = GetOutput(context, node, 0);
  }
  TfLiteReducerParams* params;
  const TfLiteTensor* input;
  const TfLiteTensor* axis;
  TfLiteTensor* output;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  // Creates three temp tensors for internal use: an iteration index, the
  // resolved axis, and a buffer for intermediate sums.
  auto* op_data = new OpData();
  context->AddTensors(context, 3, &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

// Resizes the temp tensor that stores resolved axis.
TfLiteStatus ResizeTempAxis(TfLiteContext* context, OpContext* op_context,
                            TfLiteTensor* resolved_axis) {
  TfLiteIntArray* axis_size = TfLiteIntArrayCreate(1);
  axis_size->data[0] = static_cast<int>(NumElements(op_context->axis));
  return context->ResizeTensor(context, resolved_axis, axis_size);
}

// Resizes the temp tensor that stores temp sum of reduced elements.
TfLiteStatus ResizeTempSum(TfLiteContext* context, OpContext* op_context,
                           TfLiteTensor* temp_sum) {
  TfLiteIntArray* size = TfLiteIntArrayCreate(1);
  size->data[0] = static_cast<int>(NumElements(op_context->output));
  return context->ResizeTensor(context, temp_sum, size);
}
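
// For example (illustrative values): reducing an input of shape [2, 3, 4]
// over axis {1} with keep_dims=false yields an output of shape [2, 4], so
// ResizeTempSum above sizes temp_sum as a flat buffer of 8 accumulators,
// one per output element.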

// Resizes output array based on the input size and resolved axis.
TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) {
  size_t num_axis = NumElements(op_context->axis);
  const TfLiteIntArray* input_dims = op_context->input->dims;
  int input_num_dims = NumDimensions(op_context->input);
  if (input_num_dims == 0) {
    return context->ResizeTensor(context, op_context->output,
                                 TfLiteIntArrayCreate(0));
  }
  const int* axis = GetTensorData<int>(op_context->axis);
  if (op_context->params->keep_dims) {
    TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_num_dims);
    for (int idx = 0; idx < input_num_dims; ++idx) {
      bool is_axis = false;
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) {
          is_axis = true;
          break;
        }
      }
      if (is_axis) {
        output_dims->data[idx] = 1;
      } else {
        output_dims->data[idx] = input_dims->data[idx];
      }
    }
    return context->ResizeTensor(context, op_context->output, output_dims);
  } else {
    // Calculates size of reducing axis.
    int num_reduce_axis = num_axis;
    for (int i = 0; i < num_axis; ++i) {
      int current = axis[i];
      if (current < 0) {
        current += input_num_dims;
      }
      TF_LITE_ENSURE(context, current >= 0 && current < input_num_dims);
      for (int j = 0; j < i; ++j) {
        int previous = axis[j];
        if (previous < 0) {
          previous += input_num_dims;
        }
        if (current == previous) {
          --num_reduce_axis;
          break;
        }
      }
    }
    // Determines output dimensions.
    TfLiteIntArray* output_dims =
        TfLiteIntArrayCreate(input_num_dims - num_reduce_axis);
    int num_skip_axis = 0;
    for (int idx = 0; idx < input_num_dims; ++idx) {
      bool is_axis = false;
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) {
          ++num_skip_axis;
          is_axis = true;
          break;
        }
      }
      if (!is_axis) {
        output_dims->data[idx - num_skip_axis] = input_dims->data[idx];
      }
    }
    return context->ResizeTensor(context, op_context->output, output_dims);
  }
}
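
// Worked example (illustrative values): for an input of shape [2, 3, 4] and
// an axis tensor of {-1} (resolved to 2), ResizeOutputTensor produces
// [2, 3, 1] when keep_dims is true and [2, 3] when it is false. Duplicate
// and negative axis entries are collapsed by the num_reduce_axis loop above.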

// Initializes temp tensors to store index and resolved axis.
TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
                                   OpContext* op_context) {
  // Creates a temp index to iterate through input data.
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(3);
  node->temporaries->data[0] = op_data->scratch_tensor_index;
  TfLiteTensor* scratch_tensor;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/0, &scratch_tensor));
  scratch_tensor->type = kTfLiteInt32;
  scratch_tensor->allocation_type = kTfLiteArenaRw;
  TfLiteIntArray* index_size = TfLiteIntArrayCreate(1);
  index_size->data[0] = NumDimensions(op_context->input);
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, scratch_tensor, index_size));

  // Creates a temp tensor to store resolved axis given input data.
  node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
  TfLiteTensor* resolved_axis;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
  resolved_axis->type = kTfLiteInt32;
  // Creates a temp tensor to store temp sums when calculating mean.
  node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
  TfLiteTensor* temp_sum;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
  switch (op_context->input->type) {
    case kTfLiteFloat32:
      temp_sum->type = kTfLiteFloat32;
      break;
    case kTfLiteInt32:
      temp_sum->type = kTfLiteInt64;
      break;
    case kTfLiteInt64:
      temp_sum->type = kTfLiteInt64;
      break;
    case kTfLiteUInt8:
    case kTfLiteInt8:
    case kTfLiteInt16:
      temp_sum->type = kTfLiteInt32;
      break;
    case kTfLiteBool:
      temp_sum->type = kTfLiteBool;
      break;
    default:
      return kTfLiteError;
  }
  return kTfLiteOk;
}
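
// The accumulator type is deliberately wider than the input type; e.g.
// (illustrative values) summing a million uint8 values can reach
// 255 * 10^6 ~= 2.6e8, which overflows the 8- and 16-bit types but fits
// comfortably in the int32 temp_sum chosen above.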

TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  OpContext op_context(context, node);
  TF_LITE_ENSURE_TYPES_EQ(context, op_context.axis->type, kTfLiteInt32);
  TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context));

  if (op_context.input->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point, 0);
  }

  TfLiteTensor* resolved_axis;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
  // Leaves work to Eval if axis is not constant; else resizes output.
  if (!IsConstantTensor(op_context.axis)) {
    SetTensorToDynamic(op_context.output);
    SetTensorToDynamic(resolved_axis);
    return kTfLiteOk;
  }
  resolved_axis->allocation_type = kTfLiteArenaRw;
  TF_LITE_ENSURE_OK(context,
                    ResizeTempAxis(context, &op_context, resolved_axis));
  TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
  return kTfLiteOk;
}
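
// E.g. (illustrative values): a constant axis {0} on a [2, 3] input lets
// Prepare fix the output to [3] (keep_dims=false) so the arena can be
// planned ahead of time, whereas an axis computed at runtime leaves the
// output and resolved_axis dynamic until Eval.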

TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteBool);
  return PrepareSimple(context, node);
}

TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  // reduce_mean requires a buffer to store intermediate sum result.
  OpContext op_context(context, node);
  if (op_context.input->type == kTfLiteInt8 ||
      op_context.input->type == kTfLiteUInt8 ||
      op_context.input->type == kTfLiteInt16) {
    const double real_multiplier =
        static_cast<double>(op_context.input->params.scale) /
        static_cast<double>(op_context.output->params.scale);
    int exponent;
    QuantizeMultiplier(real_multiplier, &data->multiplier, &exponent);
    data->shift = exponent;
  }

  if (op_context.input->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point, 0);
  }

  TfLiteTensor* temp_sum;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
  if (!IsConstantTensor(op_context.axis)) {
    SetTensorToDynamic(temp_sum);
    return kTfLiteOk;
  }
  temp_sum->allocation_type = kTfLiteArenaRw;
  return ResizeTempSum(context, &op_context, temp_sum);
}
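
// Worked example (illustrative scales): with input scale 0.5 and output
// scale 0.25, real_multiplier is 2.0. QuantizeMultiplier factors it as
// q * 2^shift with q a Q31 fixed-point value in [0.5, 1.0), giving
// multiplier = 1 << 30 (0.5 in Q31) and shift = 2, since 0.5 * 2^2 == 2.0.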

void ResolveAxis(const int* axis_data, int axis_count,
                 tflite::MeanParams* op_params) {
  int i = 0;
  for (; i < axis_count; ++i) {
    op_params->axis[i] = static_cast<int16>(axis_data[i]);
  }
  for (; i < 4; ++i) {
    op_params->axis[i] = 1;
  }
}
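
// For example: axis_data = {1, 2} with axis_count = 2 fills op_params->axis
// as {1, 2, 1, 1}; the trailing slots are padded with 1 so checks against
// the fixed-size 4-element axis array below are well defined.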

template <typename integer_type>
TfLiteStatus EvalMeanReferenceOps(TfLiteContext* context,
                                  const OpContext& op_context, int num_axis,
                                  OpData* data, TfLiteTensor* temp_index,
                                  TfLiteTensor* resolved_axis,
                                  TfLiteTensor* temp_sum) {
  tflite::MeanParams op_params;
  op_params.axis_count = num_axis;
  ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
  const TfLiteTensor* input = op_context.input;
  // Return early when input shape has zero dim.
  for (int i = 0; i < input->dims->size; ++i) {
    if (input->dims->data[i] == 0) return kTfLiteOk;
  }

  // TODO(b/139102329): Handle all the cases in the combined reference
  // method.
  if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
      op_params.axis_count == 2 &&
      ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
       (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
    if (std::is_same<integer_type, uint8_t>::value) {
      reference_ops::Mean(op_params, GetTensorShape(op_context.input),
                          GetTensorData<uint8_t>(op_context.input),
                          op_context.input->params.zero_point,
                          op_context.input->params.scale,
                          GetTensorShape(op_context.output),
                          GetTensorData<uint8_t>(op_context.output),
                          op_context.output->params.zero_point,
                          op_context.output->params.scale);
    } else {
      reference_integer_ops::Mean(
          op_params, data->multiplier, data->shift, GetTensorShape(input),
          GetTensorData<integer_type>(input),
          op_context.input->params.zero_point,
          GetTensorShape(op_context.output),
          GetTensorData<integer_type>(op_context.output),
          op_context.output->params.zero_point);
    }
  } else if (input->params.zero_point == op_context.output->params.zero_point &&
             input->params.scale == op_context.output->params.scale) {
    TF_LITE_ENSURE(
        context,
        reference_ops::Mean(
            GetTensorData<integer_type>(input), input->dims->data,
            input->dims->size, GetTensorData<integer_type>(op_context.output),
            op_context.output->dims->data, op_context.output->dims->size,
            GetTensorData<int>(op_context.axis), num_axis,
            op_context.params->keep_dims, GetTensorData<int>(temp_index),
            GetTensorData<int>(resolved_axis), GetTensorData<int>(temp_sum)));
  } else {
    TF_LITE_ENSURE(
        context,
        reference_ops::QuantizedMeanOrSum<>(
            GetTensorData<integer_type>(input), input->params.zero_point,
            input->params.scale, input->dims->data, input->dims->size,
            GetTensorData<integer_type>(op_context.output),
            op_context.output->params.zero_point,
            op_context.output->params.scale, op_context.output->dims->data,
            op_context.output->dims->size, GetTensorData<int>(op_context.axis),
            num_axis, op_context.params->keep_dims,
            GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
            GetTensorData<int>(temp_sum),
            /*compute_sum=*/false));
  }
  return kTfLiteOk;
}
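
// Dispatch summary for the template above: the first branch hits the
// specialized 4D mean over axes 1 and 2 (e.g. H and W of an NHWC tensor,
// with keep_dims); the second handles the common case where input and
// output share quantization parameters; the last rescales through
// QuantizedMeanOrSum when the scales or zero points differ.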

template <KernelType kernel_type>
TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  int num_axis = static_cast<int>(NumElements(op_context.axis));
  TfLiteTensor* temp_index;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/0, &temp_index));
  TfLiteTensor* resolved_axis;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
  TfLiteTensor* temp_sum;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
  // Resize the output tensor if the output tensor is dynamic.
  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context,
                      ResizeTempAxis(context, &op_context, resolved_axis));
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
    TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
  }

  // Return early when input shape has zero dim.
  const TfLiteTensor* input = op_context.input;
  for (int i = 0; i < input->dims->size; ++i) {
    if (input->dims->data[i] == 0) return kTfLiteOk;
  }

  if (kernel_type == kGenericOptimized) {
    // Use optimized ops if available.
    switch (input->type) {
      case kTfLiteInt8: {
        tflite::MeanParams op_params;
        op_params.axis_count = num_axis;
        ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
        if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
            op_params.axis_count == 2 &&
            ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
             (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
          optimized_integer_ops::Mean(
              op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
              input->params.zero_point, input->params.scale,
              GetTensorShape(op_context.output),
              GetTensorData<int8_t>(op_context.output),
              op_context.output->params.zero_point,
              op_context.output->params.scale,
              CpuBackendContext::GetFromContext(context));
          return kTfLiteOk;
        }
      } break;
      case kTfLiteUInt8: {
        tflite::MeanParams op_params;
        op_params.axis_count = num_axis;
        ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
        if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
            op_params.axis_count == 2 &&
            ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
             (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
          optimized_ops::Mean(op_params, GetTensorShape(input),
                              GetTensorData<uint8_t>(input),
                              input->params.zero_point, input->params.scale,
                              GetTensorShape(op_context.output),
                              GetTensorData<uint8_t>(op_context.output),
                              op_context.output->params.zero_point,
                              op_context.output->params.scale,
                              CpuBackendContext::GetFromContext(context));
          return kTfLiteOk;
        }
      } break;
      default:
        break;
    }
  }

  // From here on, the reference implementations are used.
  // TODO(b/139102329): Clean up the function signatures to merge the
  // variations and handle the specialized cases in the combined reference
  // implementations for each op.
  switch (op_context.input->type) {
    case kTfLiteFloat32: {
      tflite::MeanParams op_params;
      op_params.axis_count = num_axis;
      ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
      const TfLiteTensor* input = op_context.input;
      // TODO(b/139102329): Handle the below special case in the combined
      // reference method.
      // Defer to specialized implementation for 4D Mean across axes 1 & 2.
      if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
          op_params.axis_count == 2 &&
          ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
           (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
        reference_ops::Mean(op_params, GetTensorShape(input),
                            GetTensorData<float>(input),
                            GetTensorShape(op_context.output),
                            GetTensorData<float>(op_context.output));
      } else {
        TF_LITE_ENSURE(
            context,
            optimized_ops::MeanGeneral(
                GetTensorData<float>(op_context.input),
                op_context.input->dims->data, op_context.input->dims->size,
                GetTensorData<float>(op_context.output),
                op_context.output->dims->data, op_context.output->dims->size,
                GetTensorData<int>(op_context.axis), num_axis,
                op_context.params->keep_dims, GetTensorData<int>(temp_index),
                GetTensorData<int>(resolved_axis),
                GetTensorData<float>(temp_sum)));
      }
    } break;
    case kTfLiteInt32:
      TF_LITE_ENSURE(
          context,
          reference_ops::Mean(
              GetTensorData<int>(op_context.input),
              op_context.input->dims->data, op_context.input->dims->size,
              GetTensorData<int>(op_context.output),
              op_context.output->dims->data, op_context.output->dims->size,
              GetTensorData<int>(op_context.axis), num_axis,
              op_context.params->keep_dims, GetTensorData<int>(temp_index),
              GetTensorData<int>(resolved_axis),
              GetTensorData<int64_t>(temp_sum)));
      break;
    case kTfLiteInt64:
      TF_LITE_ENSURE(
          context,
          reference_ops::Mean(
              GetTensorData<int64_t>(op_context.input),
              op_context.input->dims->data, op_context.input->dims->size,
              GetTensorData<int64_t>(op_context.output),
              op_context.output->dims->data, op_context.output->dims->size,
              GetTensorData<int>(op_context.axis), num_axis,
              op_context.params->keep_dims, GetTensorData<int>(temp_index),
              GetTensorData<int>(resolved_axis),
              GetTensorData<int64_t>(temp_sum)));
      break;
    case kTfLiteInt8: {
      TF_LITE_ENSURE_OK(context, EvalMeanReferenceOps<int8_t>(
                                     context, op_context, num_axis, data,
                                     temp_index, resolved_axis, temp_sum));
    } break;
    case kTfLiteInt16: {
      TF_LITE_ENSURE_OK(context, EvalMeanReferenceOps<int16_t>(
                                     context, op_context, num_axis, data,
                                     temp_index, resolved_axis, temp_sum));
    } break;
    case kTfLiteUInt8: {
      TF_LITE_ENSURE_OK(context, EvalMeanReferenceOps<uint8_t>(
                                     context, op_context, num_axis, data,
                                     temp_index, resolved_axis, temp_sum));
    } break;
    default:
      return kTfLiteError;
  }
  return kTfLiteOk;
}

// The underlying logic for Reduce Sum/Prod/Max/Min/Any
template <typename T>
TfLiteStatus EvalLogic(TfLiteContext* context, TfLiteNode* node,
                       OpContext* op_context, T init_value,
                       T reducer(const T current, const T in)) {
  int64_t num_axis = NumElements(op_context->axis);
  TfLiteTensor* temp_index;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/0, &temp_index));
  TfLiteTensor* resolved_axis;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
  // Resize the output tensor if the output tensor is dynamic.
  if (IsDynamicTensor(op_context->output)) {
    TF_LITE_ENSURE_OK(context,
                      ResizeTempAxis(context, op_context, resolved_axis));
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, op_context));
  }

  const TfLiteTensor* input = op_context->input;
  // Return early when input shape has zero dim.
  for (int i = 0; i < input->dims->size; ++i) {
    if (input->dims->data[i] == 0) return kTfLiteOk;
  }

  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
      input->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, input->params.scale,
                      op_context->output->params.scale);
    TF_LITE_ENSURE_EQ(context, input->params.zero_point,
                      op_context->output->params.zero_point);
  }
  TF_LITE_ENSURE(
      context,
      reference_ops::ReduceGeneric<T>(
          GetTensorData<T>(input), input->dims->data, input->dims->size,
          GetTensorData<T>(op_context->output), op_context->output->dims->data,
          op_context->output->dims->size, GetTensorData<int>(op_context->axis),
          num_axis, op_context->params->keep_dims,
          GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
          init_value, reducer));
  return kTfLiteOk;
}
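
// EvalLogic folds every reduced element into an accumulator seeded with
// init_value; e.g. a sum over {3, 4, 5} runs
// reducer(reducer(reducer(0, 3), 4), 5) == 12. The init_value must be the
// identity of the reducer, as the EvalType switch below illustrates.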

enum ReduceType {
  kSum,
  kProd,
  kMax,
  kMin,
  kAny,
};

// Eval for determined input type and reduce type.
template <typename T>
TfLiteStatus EvalType(TfLiteContext* context, TfLiteNode* node,
                      OpContext* op_context, ReduceType reduce_type) {
  switch (reduce_type) {
    case kSum:
      return EvalLogic<T>(
          context, node, op_context, static_cast<T>(0),
          [](const T current, const T in) -> T { return in + current; });
      break;
    case kProd:
      return EvalLogic<T>(
          context, node, op_context, static_cast<T>(1),
          [](const T current, const T in) -> T { return in * current; });
      break;
    case kMax:
      return EvalLogic<T>(context, node, op_context,
                          std::numeric_limits<T>::lowest(),
                          [](const T current, const T in) -> T {
                            return (in > current) ? in : current;
                          });
      break;
    case kMin:
      return EvalLogic<T>(context, node, op_context,
                          std::numeric_limits<T>::max(),
                          [](const T current, const T in) -> T {
                            return (in < current) ? in : current;
                          });
      break;
    default:
      return kTfLiteError;
  }
}

// Template specialization for bool type
template <>
TfLiteStatus EvalType<bool>(TfLiteContext* context, TfLiteNode* node,
                            OpContext* op_context, ReduceType reduce_type) {
  switch (reduce_type) {
    case kAny:
      return EvalLogic<bool>(context, node, op_context, false,
                             [](const bool current, const bool in) -> bool {
                               return in || current;
                             });
      break;
    default:
      return kTfLiteError;
  }
}

// The entry point that handles input types and then calls template functions
// to handle ReduceType.
template <KernelType kernel_type, ReduceType reduce_type>
TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) {
  if (kernel_type != kReference) {
    return kTfLiteOk;
  }
  OpContext op_context(context, node);
  switch (op_context.input->type) {
    case kTfLiteFloat32:
      return EvalType<float>(context, node, &op_context, reduce_type);
      break;
    case kTfLiteInt32:
      return EvalType<int>(context, node, &op_context, reduce_type);
      break;
    case kTfLiteInt64:
      return EvalType<int64_t>(context, node, &op_context, reduce_type);
      break;
    case kTfLiteUInt8:
      return EvalType<uint8_t>(context, node, &op_context, reduce_type);
      break;
    case kTfLiteInt8:
      return EvalType<int8_t>(context, node, &op_context, reduce_type);
      break;
    case kTfLiteInt16:
      return EvalType<int16_t>(context, node, &op_context, reduce_type);
      break;
    case kTfLiteBool:
      return EvalType<bool>(context, node, &op_context, reduce_type);
      break;
    default:
      return kTfLiteError;
  }
}

TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);
  ruy::profiler::ScopeLabel label("Sum");
  const auto& input = op_context.input;
  const auto& output = op_context.output;
  const bool same_scale =
      (input->params.scale == output->params.scale &&
       input->params.zero_point == output->params.zero_point);
  const bool eight_bit_quantized =
      input->type == kTfLiteUInt8 || input->type == kTfLiteInt8;
  const bool need_rescale = (eight_bit_quantized && !same_scale);
  if (need_rescale) {
    // Rescaling 8-bit reduce sum.
    int num_axis = static_cast<int>(NumElements(op_context.axis));
    TfLiteTensor* temp_index;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/0, &temp_index));
    TfLiteTensor* resolved_axis;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
    TfLiteTensor* temp_sum;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
    // Resize the output tensor if the output tensor is dynamic.
    if (IsDynamicTensor(op_context.output)) {
      TF_LITE_ENSURE_OK(context,
                        ResizeTempAxis(context, &op_context, resolved_axis));
      TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
      TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
    }
    // Return early when input shape has zero dim.
    for (int i = 0; i < input->dims->size; ++i) {
      if (input->dims->data[i] == 0) return kTfLiteOk;
    }

    if (input->type == kTfLiteUInt8) {
      TF_LITE_ENSURE(
          context,
          reference_ops::QuantizedMeanOrSum<>(
              GetTensorData<uint8_t>(op_context.input),
              op_context.input->params.zero_point,
              op_context.input->params.scale, op_context.input->dims->data,
              op_context.input->dims->size,
              GetTensorData<uint8_t>(op_context.output),
              op_context.output->params.zero_point,
              op_context.output->params.scale, op_context.output->dims->data,
              op_context.output->dims->size,
              GetTensorData<int>(op_context.axis), num_axis,
              op_context.params->keep_dims, GetTensorData<int>(temp_index),
              GetTensorData<int>(resolved_axis), GetTensorData<int32>(temp_sum),
              /*compute_sum=*/true));
    }
    if (input->type == kTfLiteInt8) {
      TF_LITE_ENSURE(
          context,
          reference_ops::QuantizedMeanOrSum<>(
              GetTensorData<int8_t>(op_context.input),
              op_context.input->params.zero_point,
              op_context.input->params.scale, op_context.input->dims->data,
              op_context.input->dims->size,
              GetTensorData<int8_t>(op_context.output),
              op_context.output->params.zero_point,
              op_context.output->params.scale, op_context.output->dims->data,
              op_context.output->dims->size,
              GetTensorData<int>(op_context.axis), num_axis,
              op_context.params->keep_dims, GetTensorData<int>(temp_index),
              GetTensorData<int>(resolved_axis), GetTensorData<int32>(temp_sum),
              /*compute_sum=*/true));
    }
  } else {
    return EvalGeneric<kReference, kSum>(context, node);
  }

  return kTfLiteOk;
}
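
// The rescale path above triggers only for 8-bit inputs whose quantization
// differs from the output's; e.g. (illustrative values) an int8 input with
// scale 0.5 summed into an int8 output with scale 4.0 goes through
// QuantizedMeanOrSum with compute_sum=true, while matching parameters fall
// back to the generic same-scale kSum path.
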
}  // namespace reduce

TfLiteRegistration* Register_MEAN_OPT() {
  static TfLiteRegistration r = {reduce::Init, reduce::Free,
                                 reduce::PrepareMeanOrSum,
                                 reduce::EvalMean<reduce::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_MEAN_REF() {
  static TfLiteRegistration r = {reduce::Init, reduce::Free,
                                 reduce::PrepareMeanOrSum,
                                 reduce::EvalMean<reduce::kReference>};
  return &r;
}

TfLiteRegistration* Register_SUM_REF() {
  static TfLiteRegistration r = {reduce::Init, reduce::Free,
                                 reduce::PrepareMeanOrSum, reduce::EvalSum};
  return &r;
}

TfLiteRegistration* Register_REDUCE_PROD_REF() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareSimple,
      reduce::EvalGeneric<reduce::kReference, reduce::kProd>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_MAX_REF() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareSimple,
      reduce::EvalGeneric<reduce::kReference, reduce::kMax>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_MIN_REF() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareSimple,
      reduce::EvalGeneric<reduce::kReference, reduce::kMin>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_ANY_REF() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareAny,
      reduce::EvalGeneric<reduce::kReference, reduce::kAny>};
  return &r;
}

TfLiteRegistration* Register_MEAN() {
#ifdef USE_NEON
  return Register_MEAN_OPT();
#else
  return Register_MEAN_REF();
#endif
}

TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); }
TfLiteRegistration* Register_REDUCE_PROD() {
  return Register_REDUCE_PROD_REF();
}
TfLiteRegistration* Register_REDUCE_MAX() { return Register_REDUCE_MAX_REF(); }
TfLiteRegistration* Register_REDUCE_MIN() { return Register_REDUCE_MIN_REF(); }
TfLiteRegistration* Register_REDUCE_ANY() { return Register_REDUCE_ANY_REF(); }
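
// A minimal usage sketch (assuming the standard MutableOpResolver API): a
// custom resolver could wire these entry points up as, e.g.,
//   tflite::MutableOpResolver resolver;
//   resolver.AddBuiltin(tflite::BuiltinOperator_MEAN, Register_MEAN());
//   resolver.AddBuiltin(tflite::BuiltinOperator_SUM, Register_SUM());
// The built-in op resolver performs the equivalent registration for you.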

}  // namespace builtin
}  // namespace ops
}  // namespace tflite