1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_REDUCE_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_REDUCE_OPS_H_
17 
18 #include <stdint.h>
19 
20 #include <algorithm>
21 #include <limits>
22 
23 #include "ruy/profiler/instrumentation.h"  // from @ruy
24 #include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
25 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops_utils.h"
26 #include "tensorflow/lite/kernels/internal/optimized/reduce_utils.h"
27 #include "tensorflow/lite/kernels/internal/reduce_common.h"
28 #include "tensorflow/lite/kernels/internal/reference/reduce.h"
29 #include "tensorflow/lite/kernels/internal/runtime_shape.h"
30 #include "tensorflow/lite/kernels/internal/types.h"
31 #include "tensorflow/lite/kernels/kernel_util.h"
32 
33 namespace tflite {
34 namespace optimized_ops {
35 
36 inline void MeanImpl(const tflite::MeanParams& op_params,
37                      const RuntimeShape& input_shape, const uint8_t* input_data,
38                      int32 multiplier, int32 shift, int32 bias,
39                      const RuntimeShape& output_shape, uint8_t* output_data,
40                      int start_depth, int end_depth) {
41   ruy::profiler::ScopeLabel label("Mean4D/Uint8/MeanImpl");
42 
43   // The current implementation only supports 4-D inputs and simultaneous
44   // reduction over width and height.
45   const int output_batch = output_shape.Dims(0);
46   const int output_height = output_shape.Dims(1);
47   const int output_width = output_shape.Dims(2);
48   const int input_height = input_shape.Dims(1);
49   const int input_width = input_shape.Dims(2);
50 
51   TFLITE_CHECK_EQ(op_params.axis_count, 2);
52   TFLITE_CHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
53                (op_params.axis[0] == 2 && op_params.axis[1] == 1));
54   TFLITE_CHECK_EQ(output_height, 1);
55   TFLITE_CHECK_EQ(output_width, 1);
56 
57   constexpr int32_t kMinValue = std::numeric_limits<uint8_t>::min();
58   constexpr int32_t kMaxValue = std::numeric_limits<uint8_t>::max();
59 
60 #ifdef USE_NEON
61   const int32x4_t bias_dup = vdupq_n_s32(bias);
62   const int32x4_t min_dup = vdupq_n_s32(kMinValue);
63   const int32x4_t max_dup = vdupq_n_s32(kMaxValue);
64 #endif  // USE_NEON
65 
66   for (int out_b = 0; out_b < output_batch; ++out_b) {
67     int out_d = start_depth;
68 #ifdef USE_NEON
69 
70     for (; out_d <= end_depth - 16; out_d += 16) {
71       int32x4x4_t temp_sum;
72       temp_sum.val[0] = vdupq_n_s32(0);
73       temp_sum.val[1] = vdupq_n_s32(0);
74       temp_sum.val[2] = vdupq_n_s32(0);
75       temp_sum.val[3] = vdupq_n_s32(0);
76       for (int in_h = 0; in_h < input_height; ++in_h) {
77         for (int in_w = 0; in_w < input_width; ++in_w) {
78           const uint8_t* input_data_ptr =
79               input_data + Offset(input_shape, out_b, in_h, in_w, out_d);
80           uint8x16_t input_data_val = vld1q_u8(input_data_ptr);
81 
82           int16x8_t input_data_low_shift =
83               vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_data_val)));
84           int16x8_t input_data_high_shift =
85               vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_data_val)));
86 
87           int32x4_t input_low_low =
88               vmovl_s16(vget_low_s16(input_data_low_shift));
89           int32x4_t input_high_low =
90               vmovl_s16(vget_high_s16(input_data_low_shift));
91           int32x4_t input_low_high =
92               vmovl_s16(vget_low_s16(input_data_high_shift));
93           int32x4_t input_high_high =
94               vmovl_s16(vget_high_s16(input_data_high_shift));
95 
96           temp_sum.val[0] = vaddq_s32(temp_sum.val[0], input_low_low);
97           temp_sum.val[1] = vaddq_s32(temp_sum.val[1], input_high_low);
98           temp_sum.val[2] = vaddq_s32(temp_sum.val[2], input_low_high);
99           temp_sum.val[3] = vaddq_s32(temp_sum.val[3], input_high_high);
100         }
101       }
102 
103       temp_sum =
104           MultiplyByQuantizedMultiplier4Rows(temp_sum, multiplier, shift);
105 
106       temp_sum.val[0] = vaddq_s32(temp_sum.val[0], bias_dup);
107       temp_sum.val[1] = vaddq_s32(temp_sum.val[1], bias_dup);
108       temp_sum.val[2] = vaddq_s32(temp_sum.val[2], bias_dup);
109       temp_sum.val[3] = vaddq_s32(temp_sum.val[3], bias_dup);
110 
111       temp_sum.val[0] = vminq_s32(vmaxq_s32(temp_sum.val[0], min_dup), max_dup);
112       temp_sum.val[1] = vminq_s32(vmaxq_s32(temp_sum.val[1], min_dup), max_dup);
113       temp_sum.val[2] = vminq_s32(vmaxq_s32(temp_sum.val[2], min_dup), max_dup);
114       temp_sum.val[3] = vminq_s32(vmaxq_s32(temp_sum.val[3], min_dup), max_dup);
115 
116       uint16x4_t narrowed_low_low =
117           vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[0]));
118       uint16x4_t narrowed_high_low =
119           vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[1]));
120       uint16x4_t narrowed_low_high =
121           vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[2]));
122       uint16x4_t narrowed_high_high =
123           vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[3]));
124 
125       uint16x8_t combined_low =
126           vcombine_u16(narrowed_low_low, narrowed_high_low);
127       uint16x8_t combined_high =
128           vcombine_u16(narrowed_low_high, narrowed_high_high);
129 
130       uint8x8_t narrowed_low = vmovn_u16(combined_low);
131       uint8x8_t narrowed_high = vmovn_u16(combined_high);
132 
133       uint8x16_t combined_output = vcombine_u8(narrowed_low, narrowed_high);
134 
135       uint8_t* output_data_ptr =
136           output_data + Offset(output_shape, out_b, 0, 0, out_d);
137       vst1q_u8(output_data_ptr, combined_output);
138     }
139 #endif  // USE_NEON
140 
141     for (; out_d < end_depth; ++out_d) {
142       int acc = 0;
143       for (int in_h = 0; in_h < input_height; ++in_h) {
144         for (int in_w = 0; in_w < input_width; ++in_w) {
145           acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
146         }
147       }
148 
149       acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
150       acc += bias;
151       acc = std::min(std::max(acc, kMinValue), kMaxValue);
152       output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
153           static_cast<uint8_t>(acc);
154     }
155   }
156 }
157 
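// The NEON block above processes 16 output channels per loop iteration. A
// sketch of its widen/accumulate/narrow pipeline (descriptive only, mirroring
// the intrinsics used above):
//
//   load:   uint8x16 -> 2 x int16x8 (vmovl_u8) -> 4 x int32x4 (vmovl_s16)
//   reduce: accumulate the 4 registers over H x W, then
//           MultiplyByQuantizedMultiplier4Rows, add bias, clamp to [0, 255]
//   store:  4 x int32x4 -> 2 x uint16x8 (vmovn_u32) -> uint8x16 (vmovn_u16)
//
// Channels that do not fill a full group of 16 fall through to the scalar
// loop, which performs the same arithmetic one channel at a time.
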
158 struct MeanWorkerTask : cpu_backend_threadpool::Task {
159   MeanWorkerTask(const tflite::MeanParams& op_params,
160                  const RuntimeShape& input_shape, const uint8_t* input_data,
161                  int32 multiplier, int32 shift, int32 bias,
162                  const RuntimeShape& output_shape, uint8_t* output_data,
163                  int start_height, int end_height)
164       : op_params(op_params),
165         input_shape(input_shape),
166         input_data(input_data),
167         multiplier(multiplier),
168         shift(shift),
169         bias(bias),
170         output_shape(output_shape),
171         output_data(output_data),
172         start_height(start_height),
173         end_height(end_height) {}
174 
175   void Run() override {
176     MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias,
177              output_shape, output_data, start_height, end_height);
178   }
179 
180  private:
181   const tflite::MeanParams& op_params;
182   const RuntimeShape& input_shape;
183   const uint8_t* input_data;
184   int32 multiplier;
185   int32 shift;
186   int32 bias;
187   const RuntimeShape& output_shape;
188   uint8_t* output_data;
189   int start_height;
190   int end_height;
191 };
192 
193 inline void Mean(const tflite::MeanParams& op_params,
194                  const RuntimeShape& unextended_input_shape,
195                  const uint8_t* input_data, int32 input_zero_point,
196                  float input_scale, const RuntimeShape& unextended_output_shape,
197                  uint8_t* output_data, int32 output_zero_point,
198                  float output_scale, CpuBackendContext* cpu_backend_context) {
199   ruy::profiler::ScopeLabel label("Mean4D/Uint8");
200   // The current implementation only supports 4-D inputs and simultaneous
201   // reduction over width and height.
202   TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4);
203   TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4);
204   const RuntimeShape input_shape =
205       RuntimeShape::ExtendedShape(4, unextended_input_shape);
206   const RuntimeShape output_shape =
207       RuntimeShape::ExtendedShape(4, unextended_output_shape);
208   const int output_height = output_shape.Dims(1);
209   const int output_width = output_shape.Dims(2);
210   const int output_depth = output_shape.Dims(3);
211 
212   TFLITE_CHECK_EQ(op_params.axis_count, 2);
213   TFLITE_CHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
214                (op_params.axis[0] == 2 && op_params.axis[1] == 1));
215   TFLITE_CHECK_EQ(output_height, 1);
216   TFLITE_CHECK_EQ(output_width, 1);
217 
218   const int input_height = input_shape.Dims(1);
219   const int input_width = input_shape.Dims(2);
220   const float num_elements_in_axis = input_width * input_height;
221 
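  // Derivation of `bias` and `real_scale` below (a sketch; s_in/s_out and
  // zp_in/zp_out denote the input/output scale and zero point, N the number of
  // averaged elements):
  //
  //   mean_real = s_in * (sum_q - N * zp_in) / N
  //   mean_q    = mean_real / s_out + zp_out
  //             = sum_q * [s_in / (N * s_out)] + [zp_out - zp_in * s_in / s_out]
  //                        ^^^^ real_scale ^^^     ^^^^^^^^^^ bias ^^^^^^^^^^^
  //
  // so MeanImpl() only needs to rescale the raw uint8 sum by (multiplier,
  // shift) and add the precomputed bias.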
222   float temp = input_zero_point * input_scale / output_scale;
223   temp = temp > 0 ? temp + 0.5f : temp - 0.5f;
224   int32_t bias = output_zero_point - static_cast<int32_t>(temp);
225   float real_scale = input_scale / (num_elements_in_axis * output_scale);
226 
227   int32 multiplier, shift;
228   QuantizeMultiplier(real_scale, &multiplier, &shift);
229 
230   constexpr int kMinDepthPerThread = 8;
231   int thread_count = output_depth / kMinDepthPerThread;
232   thread_count = thread_count > 0 ? thread_count : 1;
233   const int capped_thread_count =
234       std::min(thread_count, cpu_backend_context->max_num_threads());
235 
236   if (capped_thread_count == 1) {
237     MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias,
238              output_shape, output_data, 0, output_depth);
239   } else {
240     // Instead of parallelizing over the batch, we split output_depth across
241     // threads, since the batch size is typically 1.
242     std::vector<MeanWorkerTask> tasks;
243     // TODO(b/131746020) don't create new heap allocations every time.
244     // At least we make it a single heap allocation by using reserve().
245     tasks.reserve(capped_thread_count);
246     int depth_start = 0;
247     for (int i = 0; i < capped_thread_count; ++i) {
248       // Try to distribute the tasks as evenly as possible.
249       int depth_end = depth_start +
250                       (output_depth - depth_start) / (capped_thread_count - i);
251       tasks.emplace_back(op_params, input_shape, input_data, multiplier, shift,
252                          bias, output_shape, output_data, depth_start,
253                          depth_end);
254       depth_start = depth_end;
255     }
256     cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
257                                     cpu_backend_context);
258   }
259 }
260 
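// Example invocation of the quantized Mean above (illustrative; the shapes,
// quantization parameters, buffers and `context` are assumptions made for this
// sketch, not values defined in this header):
//
//   tflite::MeanParams params;
//   params.axis_count = 2;
//   params.axis[0] = 1;  // height
//   params.axis[1] = 2;  // width
//   const RuntimeShape input_shape({1, 8, 8, 32});   // NHWC
//   const RuntimeShape output_shape({1, 1, 1, 32});
//   optimized_ops::Mean(params, input_shape, input_data,
//                       /*input_zero_point=*/128, /*input_scale=*/0.5f,
//                       output_shape, output_data,
//                       /*output_zero_point=*/128, /*output_scale=*/0.5f,
//                       &context);  // context is a CpuBackendContext.
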
261 template <typename T>
262 struct SumOp {
263   inline T operator()(const T& a) const { return a; }
264   inline T operator()(const T& a, const T& b) const { return a + b; }
265   static constexpr T kNeutralElement = T(0);
266 };
267 
268 template <typename T, typename U>
269 struct CastSumOp {
270   inline U operator()(const T& a) const { return static_cast<U>(a); }
271   inline U operator()(const U& a, const T& b) const {
272     return a + static_cast<U>(b);
273   }
274   static constexpr U kNeutralElement = U(0);
275 };
276 
277 template <typename T>
278 struct ProdOp {
279   inline T operator()(const T& a) const { return a; }
280   inline T operator()(const T& a, const T& b) const { return a * b; }
281   static constexpr T kNeutralElement = T(1);
282 };
283 
284 template <typename T>
285 struct MaxOp {
286   inline T operator()(const T& a) const { return a; }
287   inline T operator()(const T& a, const T& b) const { return (a > b) ? a : b; }
288   static constexpr T kNeutralElement = std::numeric_limits<T>::lowest();
289 };
290 
291 template <typename T>
292 struct MinOp {
293   inline T operator()(const T& a) const { return a; }
294   inline T operator()(const T& a, const T& b) const { return (a < b) ? a : b; }
295   static constexpr T kNeutralElement = std::numeric_limits<T>::max();
296 };
297 
298 struct AndOp {
299   inline bool operator()(bool a) const { return a; }
300   inline bool operator()(bool a, bool b) const { return a && b; }
301   static constexpr bool kNeutralElement = true;
302 };
303 
304 struct OrOp {
305   inline bool operator()(bool a) const { return a; }
306   inline bool operator()(bool a, bool b) const { return a || b; }
307   static constexpr bool kNeutralElement = false;
308 };
309 
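// The functors above all follow the two-call protocol used by Reduce() later
// in this file: the unary overload seeds an output element from its first
// input, and the binary overload folds each subsequent input into it. Sketch
// for one output position (x0..x2 are hypothetical inputs):
//
//   SumOp<int> op;
//   int acc = op(x0);    // first write:       acc = x0
//   acc = op(acc, x1);   // subsequent writes: acc = acc + x1
//   acc = op(acc, x2);
//
// kNeutralElement is only used to pre-fill the output when the input tensor is
// empty (see ReduceDispatcher below).
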
310 // When the number of axes is zero, the reduction is simply a copy.
311 template <typename T>
312 void ReduceIsCopy(const T* input_data, const int* input_dims,
313                   const int input_num_dims, T* output_data) {
314   int num_elems = NumElements(input_dims, input_num_dims);
315   memcpy(output_data, input_data, num_elems * sizeof(T));
316 }
317 
318 // Reduces the input over either odd or even dimensions using Op.
319 // One recursive call for each dimension is made.
320 // 'depth' is the depth of recursion.
321 // 'parity' indicates whether odd or even dimensions are being reduced.
322 // ReducerFirst is applied to the first element to be written to each output
323 // position.
324 // ReducerNext is applied to each subsequent element to be written to each
325 // output position.
326 template <typename T, typename U, typename ReducerFirst, typename ReducerNext>
327 inline std::pair<const T*, U*> ReduceImpl(const T* input_data,
328                                           const int* input_dims, U* output_data,
329                                           int depth, int parity, bool next,
330                                           const ReducerFirst& reducer_first,
331                                           const ReducerNext& reducer_next) {
332   // The output pointer is incremented conditionally depending on whether the
333   // odd or even dimension is being reduced.
334   // The input pointer is always incremented as each input is read once.
335   if (depth > 0) {
336     U* future_output = output_data;
337     bool update_output = (depth % 2) == parity;
338     for (int i = 0; i < input_dims[0]; ++i) {
339       if (i > 0 && !update_output) {
340         next = true;
341       }
342       std::tie(input_data, future_output) =
343           ReduceImpl(input_data, &input_dims[1], output_data, depth - 1, parity,
344                      next, reducer_first, reducer_next);
345       if (update_output) {
346         output_data = future_output;
347       }
348     }
349     output_data = future_output;
350   } else {
351     // Reduce the final dimension.
352     if (parity) {
353       // Reduce the even dimension. The entire dimension is reduced into one
354       // value.
355       U res = next ? reducer_next(*output_data, *input_data++)
356                    : reducer_first(*input_data++);
357       for (int i = 1; i < input_dims[0]; ++i) {
358         res = reducer_next(res, *input_data++);
359       }
360       *output_data++ = res;
361     } else {
362       // Reduce the odd dimension. Each input is accumulated into a separate
363       // output.
364       if (!next) {
365         for (int i = 0; i < input_dims[0]; ++i) {
366           U res = reducer_first(*input_data++);
367           *output_data++ = res;
368         }
369       } else {
370         for (int i = 0; i < input_dims[0]; ++i) {
371           U res = *output_data;
372           res = reducer_next(res, *input_data++);
373           *output_data++ = res;
374         }
375       }
376     }
377   }
378   return {input_data, output_data};
379 }
380 
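// Worked example of the parity scheme (illustrative): for a 2 x 3 input
// reduced over its last axis, Reduce() below calls ReduceImpl with depth = 1
// and parity = 1. Each depth == 0 call collapses one row of 3 values into a
// single output, and the outer call advances the output pointer once per row:
//
//   input_dims = {2, 3}, input = {1, 2, 3, 4, 5, 6}
//   output     = {1 + 2 + 3, 4 + 5 + 6} = {6, 15}
//
// With parity = 0 (last axis kept), the roles flip: the innermost dimension is
// written element-wise and the outer, reduced dimension re-accumulates into
// the same output positions on later iterations (`next` becomes true).
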
381 // A generic reduce method that can be used for reduce_sum, reduce_mean, etc.
382 // This method iterates through input data and reduce elements along the
383 // dimensions given in axis. ReducerFirst is used the first time each output
384 // element is written and ReducerNext is used for all subsequent writes.
385 template <typename In, typename Out, typename ReducerFirst,
386           typename ReducerNext>
387 inline bool Reduce(const In* input_data, const int* input_dims,
388                    const int input_num_dims, const int* axis,
389                    const int num_axis, Out* output_data,
390                    const ReducerFirst& reducer_first,
391                    const ReducerNext& reducer_next) {
392   const int parity = (axis[num_axis - 1] == input_num_dims - 1) ? 1 : 0;
393   ReduceImpl(input_data, input_dims, output_data, input_num_dims - 1, parity,
394              /*next=*/false, reducer_first, reducer_next);
395   return true;
396 }
397 
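// Example use of Reduce() for a plain sum over the last axis (illustrative; in
// this header it is normally reached through ReduceDispatcher or
// QuantizedMeanOrSum after reduce_utils::ResolveAxis has normalized the
// dimensions):
//
//   const int input_dims[] = {2, 3};
//   const int axis[] = {1};
//   const int input[] = {1, 2, 3, 4, 5, 6};
//   int output[2];
//   Reduce<int, int, SumOp<int>, SumOp<int>>(
//       input, input_dims, /*input_num_dims=*/2, axis, /*num_axis=*/1, output,
//       SumOp<int>(), SumOp<int>());
//   // output == {6, 15}
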
398 // Computes the mean or sum of elements across dimensions given in axis.
399 // It does so in two stages: it first calculates the sum of elements along the
400 // axis and then, for the mean, divides the sum by the number of elements in the axis.
401 template <typename T, typename U>
402 bool QuantizedMeanOrSum(const T* input_data, int32_t input_zero_point,
403                         float input_scale, const int* input_dims,
404                         const int input_num_dims, T* output_data,
405                         int32_t output_zero_point, float output_scale,
406                         const int* output_dims, const int output_num_dims,
407                         const int* axis, const int num_axis_dimensions,
408                         bool keep_dims, int* normalized_dims,
409                         int* resolved_axis, U* temp_sum, bool compute_sum) {
410   ruy::profiler::ScopeLabel label(compute_sum ? "QuantizedSum"
411                                               : "QuantizedMean");
412   // Compute the number of output elements.
413   size_t num_outputs = 1;
414   for (int idx = 0; idx < output_num_dims; ++idx) {
415     size_t current = static_cast<size_t>(output_dims[idx]);
416     // Overflow prevention.
417     if (num_outputs > std::numeric_limits<size_t>::max() / current) {
418       return false;
419     }
420     num_outputs *= current;
421   }
422 
423   // Return early when the input shape has a zero dim. This is done after
424   // computing the number of outputs because there are cases where the input
425   // tensor is empty but the output tensor is not. In that case, the output
426   // tensor should keep its initialized values.
427   for (int i = 0; i < input_num_dims; ++i) {
428     if (input_dims[i] == 0) return true;
429   }
430 
431   // Resolve axis.
432   int num_resolved_axis = 0;
433   int normalized_num_dims = 0;
434   if (!reduce_utils::ResolveAxis(input_num_dims, axis, num_axis_dimensions,
435                                  resolved_axis, num_resolved_axis, input_dims,
436                                  normalized_dims, normalized_num_dims)) {
437     return false;
438   }
439 
440   if (num_resolved_axis == 0) {
441     int count = NumElements(input_dims, input_num_dims);
442     for (int i = 0; i < count; ++i) {
443       temp_sum[i] = U(input_data[i]);
444     }
445   } else {
446     if (!Reduce<T, U, CastSumOp<T, U>, CastSumOp<T, U>>(
447             input_data, normalized_dims, normalized_num_dims, resolved_axis,
448             num_resolved_axis, temp_sum, CastSumOp<T, U>(),
449             CastSumOp<T, U>())) {
450       return false;
451     }
452   }
453 
454   // Calculate mean by dividing output_data by num of aggregated element.
455   size_t num_elements_in_axis = 1;
456   for (int idx = 0; idx < num_resolved_axis; ++idx) {
457     size_t current = static_cast<size_t>(normalized_dims[resolved_axis[idx]]);
458     // Overflow prevention.
459     if (current > (std::numeric_limits<size_t>::max() / num_elements_in_axis)) {
460       return false;
461     }
462     num_elements_in_axis *= current;
463   }
464 
465   if (num_elements_in_axis > 0) {
466     const float scale = input_scale / output_scale;
467     if (compute_sum) {
468       const float bias = -input_zero_point * scale * num_elements_in_axis;
469       for (size_t idx = 0; idx < num_outputs; ++idx) {
470         const U value =
471             static_cast<U>(TfLiteRound(temp_sum[idx] * scale + bias)) +
472             output_zero_point;
473         output_data[idx] = static_cast<T>(value);
474       }
475     } else {
476       const float bias = -input_zero_point * scale;
477       for (size_t idx = 0; idx < num_outputs; ++idx) {
478         float float_mean = static_cast<float>(temp_sum[idx]) /
479                            static_cast<float>(num_elements_in_axis);
480         float result = TfLiteMin(
481             TfLiteRound(float_mean * scale + bias) + output_zero_point,
482             static_cast<float>(std::numeric_limits<T>::max()));
483         result = TfLiteMax(result,
484                            static_cast<float>(std::numeric_limits<T>::min()));
485         output_data[idx] = static_cast<T>(result);
486       }
487     }
488   }
489   return true;
490 }
491 
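// Summary of the requantization performed above (a sketch; s_in/s_out and
// zp_in/zp_out denote the input/output scale and zero point, N =
// num_elements_in_axis):
//
//   scale = s_in / s_out
//   sum:   out_q = round(temp_sum * scale - zp_in * scale * N) + zp_out
//   mean:  out_q = clamp(round((temp_sum / N) * scale - zp_in * scale) + zp_out)
//
// Both paths undo the input zero point and convert to the output scale; only
// the mean path divides by N and clamps to the numeric range of T.
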
492 using ops::builtin::reduce::ReduceType;
493 
494 template <typename T>
495 inline bool ReduceDispatcher(const T* input_data, const int* input_dims,
496                              const int input_num_dims, const int* output_dims,
497                              int output_num_dims, T* output_data,
498                              const int* axis, const int64_t num_axis_dimensions,
499                              ReduceType reduce_type) {
500   T init_value;
501   switch (reduce_type) {
502     case ReduceType::kProd:
503       init_value = ProdOp<T>::kNeutralElement;
504       break;
505     case ReduceType::kSum:
506       init_value = SumOp<T>::kNeutralElement;
507       break;
508     case ReduceType::kMin:
509       init_value = MinOp<T>::kNeutralElement;
510       break;
511     case ReduceType::kMax:
512       init_value = MaxOp<T>::kNeutralElement;
513       break;
514     default:
515       return false;
516   }
517   // Return early when the input shape has a zero dim. This is done after
518   // selecting init_value because there are cases where the input tensor is
519   // empty but the output tensor is not. In that case, the output tensor should
520   // be filled with init_value.
521   for (int i = 0; i < input_num_dims; ++i) {
522     if (input_dims[i] == 0) {
523       return reference_ops::InitTensorDataForReduce(
524           output_dims, output_num_dims, init_value, output_data);
525     }
526   }
527 
528   switch (reduce_type) {
529     case ReduceType::kProd:
530       return Reduce<T, T, ProdOp<T>, ProdOp<T>>(
531           input_data, input_dims, input_num_dims, axis, num_axis_dimensions,
532           output_data, ProdOp<T>(), ProdOp<T>());
533     case ReduceType::kSum:
534       return Reduce<T, T, SumOp<T>, SumOp<T>>(
535           input_data, input_dims, input_num_dims, axis, num_axis_dimensions,
536           output_data, SumOp<T>(), SumOp<T>());
537     case ReduceType::kMin:
538       return Reduce<T, T, MinOp<T>, MinOp<T>>(
539           input_data, input_dims, input_num_dims, axis, num_axis_dimensions,
540           output_data, MinOp<T>(), MinOp<T>());
541     case ReduceType::kMax:
542       return Reduce<T, T, MaxOp<T>, MaxOp<T>>(
543           input_data, input_dims, input_num_dims, axis, num_axis_dimensions,
544           output_data, MaxOp<T>(), MaxOp<T>());
545     default:
546       return false;
547   }
548 }
549 
550 template <>
551 inline bool ReduceDispatcher<bool>(const bool* input_data,
552                                    const int* input_dims,
553                                    const int input_num_dims,
554                                    const int* output_dims, int output_num_dims,
555                                    bool* output_data, const int* axis,
556                                    const int64_t num_axis_dimensions,
557                                    ReduceType reduce_type) {
558   bool init_value;
559   switch (reduce_type) {
560     case ReduceType::kAny:
561       init_value = OrOp::kNeutralElement;
562       break;
563     case ReduceType::kAll:
564       init_value = AndOp::kNeutralElement;
565       break;
566     default:
567       return false;
568   }
569   // Return early when the input shape has a zero dim. This is done after
570   // selecting init_value because there are cases where the input tensor is
571   // empty but the output tensor is not. In that case, the output tensor should
572   // be filled with init_value.
573   for (int i = 0; i < input_num_dims; ++i) {
574     if (input_dims[i] == 0) {
575       return reference_ops::InitTensorDataForReduce(
576           output_dims, output_num_dims, init_value, output_data);
577     }
578   }
579   switch (reduce_type) {
580     case ReduceType::kAll:
581       return Reduce<bool, bool, AndOp, AndOp>(
582           input_data, input_dims, input_num_dims, axis, num_axis_dimensions,
583           output_data, AndOp(), AndOp());
584     case ReduceType::kAny:
585       return Reduce<bool, bool, OrOp, OrOp>(
586           input_data, input_dims, input_num_dims, axis, num_axis_dimensions,
587           output_data, OrOp(), OrOp());
588     default:
589       return false;
590   }
591 }
592 
593 // Calculate the reduced product by rescaling each multiplication step to
594 // avoid an overflow.
595 template <typename T>
596 struct ReducerFirst {
597   explicit ReducerFirst(int input_zero_point_arg)
598       : input_zero_point(input_zero_point_arg) {}
599   int32_t operator()(T in) const { return in - input_zero_point; }
600   int input_zero_point;
601 };
602 
603 template <typename T>
604 struct ReducerNext {
605   ReducerNext(int32_t input_zero_point_arg, int32_t scaling_multiplier_arg,
606               int32_t scaling_shift_arg)
607       : input_zero_point(input_zero_point_arg),
608         scaling_multiplier(scaling_multiplier_arg),
609         scaling_shift(scaling_shift_arg) {}
610   int32_t operator()(int32_t current, T in) const {
611     const int64_t result =
612         static_cast<int64_t>(current) * (in - input_zero_point);
613     return MultiplyByQuantizedMultiplier(result, scaling_multiplier,
614                                          scaling_shift);
615   }
616   int32_t input_zero_point, scaling_multiplier, scaling_shift;
617 };
618 
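// Sketch of how the two functors above are combined by QuantizedReduceProd()
// below: every element is re-centered by the input zero point, and each
// multiplication step is immediately rescaled by (scaling_multiplier,
// scaling_shift) so the running product stays representable in int32:
//
//   int32_t acc = in[0] - zp_in;                          // ReducerFirst
//   for (int i = 1; i < n; ++i) {                         // ReducerNext
//     acc = MultiplyByQuantizedMultiplier(
//         static_cast<int64_t>(acc) * (in[i] - zp_in), scaling_multiplier,
//         scaling_shift);
//   }
//
// scaling_multiplier / scaling_shift are expected to encode the quantization
// scales (exactly how is up to the calling kernel); QuantizedReduceProd()
// applies one final rescale and adds output_zero_point.
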
619 template <typename T>
620 inline bool QuantizedReduceProd(
621     const T* input_data, int32_t input_zero_point,
622     const RuntimeShape& input_shape, T* output_data, int32_t output_zero_point,
623     const RuntimeShape& output_shape, const int* axis,
624     const int64_t num_axis_dimensions, int* resolved_axis, int* normalized_dims,
625     int32_t* temp_prod, int32_t scaling_multiplier, int scaling_shift) {
626   const int32_t kMinValue = std::numeric_limits<T>::min();
627   const int32_t kMaxValue = std::numeric_limits<T>::max();
628 
629   // Resolve axis.
630   int num_resolved_axis = 0;
631   int normalized_num_dims = 0;
632   if (!reduce_utils::ResolveAxis(input_shape.DimensionsCount(), axis,
633                                  num_axis_dimensions, resolved_axis,
634                                  num_resolved_axis, input_shape.DimsData(),
635                                  normalized_dims, normalized_num_dims)) {
636     return false;
637   }
638 
639   if (!Reduce<T, int32_t, ReducerFirst<T>, ReducerNext<T>>(
640           input_data, normalized_dims, normalized_num_dims, resolved_axis,
641           num_resolved_axis, temp_prod, ReducerFirst<T>(input_zero_point),
642           ReducerNext<T>(input_zero_point, scaling_multiplier,
643                          scaling_shift))) {
644     return false;
645   }
646 
647   for (int i = 0; i < output_shape.FlatSize(); i++) {
648     int32_t result =
649         MultiplyByQuantizedMultiplier(static_cast<int64_t>(temp_prod[i]),
650                                       scaling_multiplier, scaling_shift) +
651         output_zero_point;
652     result = std::min(std::max(result, kMinValue), kMaxValue);
653     output_data[i] = static_cast<T>(result);
654   }
655 
656   return true;
657 }
658 
659 template <typename T>
660 inline void Mean(const tflite::MeanParams& op_params,
661                  const RuntimeShape& input_shape, const T* input_data,
662                  const RuntimeShape& output_shape, T* output_data) {
663   return reference_ops::Mean(op_params, input_shape, input_data, output_shape,
664                              output_data);
665 }
666 
667 // Computes the mean of elements across dimensions given in axis.
668 // It does so in two stages: it first calculates the sum of elements along the
669 // axis and then divides it by the number of elements in the axis.
670 template <typename T, typename U>
671 inline bool MeanGeneral(const T* input_data, const int* input_dims,
672                         const int input_num_dims, T* output_data,
673                         const int* output_dims, const int output_num_dims,
674                         const int* axis, const int num_axis_dimensions,
675                         bool keep_dims, int* normalized_dims,
676                         int* resolved_axis, U* temp_sum) {
677   ruy::profiler::ScopeLabel label("Mean");
678   // Resolve axis.
679   int num_resolved_axis = 0;
680   int normalized_num_dims = 0;
681   if (!reduce_utils::ResolveAxis(input_num_dims, axis, num_axis_dimensions,
682                                  resolved_axis, num_resolved_axis, input_dims,
683                                  normalized_dims, normalized_num_dims)) {
684     return false;
685   }
686   if (num_resolved_axis == 0) {
687     optimized_ops::ReduceIsCopy(input_data, input_dims, input_num_dims,
688                                 output_data);
689     return true;
690   }
691   // Reset output data.
692   size_t num_outputs = 1;
693   for (int idx = 0; idx < output_num_dims; ++idx) {
694     size_t current = static_cast<size_t>(output_dims[idx]);
695     // Overflow prevention.
696     if (num_outputs > std::numeric_limits<size_t>::max() / current) {
697       return false;
698     }
699     num_outputs *= current;
700   }
701 
702   if (!Reduce<T, U, CastSumOp<T, U>, CastSumOp<T, U>>(
703           input_data, normalized_dims, normalized_num_dims, resolved_axis,
704           num_resolved_axis, temp_sum, CastSumOp<T, U>(), CastSumOp<T, U>())) {
705     return false;
706   }
707 
708   // Calculate mean by dividing output_data by num of aggregated element.
709   size_t num_elements_in_axis = 1;
710   for (int idx = 0; idx < num_resolved_axis; ++idx) {
711     size_t current = static_cast<size_t>(normalized_dims[resolved_axis[idx]]);
712     // Overflow prevention.
713     if (current > (std::numeric_limits<size_t>::max() / num_elements_in_axis)) {
714       return false;
715     }
716     num_elements_in_axis *= current;
717   }
718 
719   if (num_elements_in_axis > 0) {
720     for (size_t idx = 0; idx < num_outputs; ++idx) {
721       output_data[idx] =
722           static_cast<T>(temp_sum[idx] / static_cast<U>(num_elements_in_axis));
723     }
724   }
725   return true;
726 }
727 
728 template <typename T, typename U>
729 inline bool Mean(const T* input_data, const int* input_dims,
730                  const int input_num_dims, T* output_data,
731                  const int* output_dims, const int output_num_dims,
732                  const int* axis, const int num_axis_dimensions, bool keep_dims,
733                  int* normalized_dims, int* resolved_axis, U* temp_sum) {
734   return MeanGeneral(input_data, input_dims, input_num_dims, output_data,
735                      output_dims, output_num_dims, axis, num_axis_dimensions,
736                      false, normalized_dims, resolved_axis, temp_sum);
737 }
738 
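// Example call (illustrative; the scratch buffers and their sizes are
// assumptions made for this sketch — the TFLite reduce kernel normally
// provides them from its temporary tensors):
//
//   const int input_dims[] = {2, 3};
//   const int output_dims[] = {2, 1};
//   const int axis[] = {1};
//   const float input[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
//   float output[2];
//   int resolved_axis[2], normalized_dims[2];
//   float temp_sum[2];
//   Mean(input, input_dims, /*input_num_dims=*/2, output, output_dims,
//        /*output_num_dims=*/2, axis, /*num_axis_dimensions=*/1,
//        /*keep_dims=*/true, normalized_dims, resolved_axis, temp_sum);
//   // output == {2.f, 5.f}
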
739 // Use Eigen when Mean is calculated over the last dimension only of a float
740 // tensor.
741 template <>
742 inline bool Mean<float, float>(const float* input_data, const int* input_dims,
743                                const int input_num_dims, float* output_data,
744                                const int* output_dims,
745                                const int output_num_dims, const int* axis,
746                                const int num_axis_dimensions, bool keep_dims,
747                                int* normalized_dims, int* resolved_axis,
748                                float* temp_sum) {
749   // Handle reduce_mean for the last dimensions.
750   int num_resolved_axis = 0;
751   int normalized_num_dims = 0;
752   if (!reduce_utils::ResolveAxis(input_num_dims, axis, num_axis_dimensions,
753                                  resolved_axis, num_resolved_axis, input_dims,
754                                  normalized_dims, normalized_num_dims)) {
755     return false;
756   }
757   if (normalized_num_dims > 1 && num_resolved_axis == 1 &&
758       resolved_axis[0] == (normalized_num_dims - 1)) {
759     ruy::profiler::ScopeLabel label("MeanLastDim/Float");
760     int output_size = normalized_dims[0];
761     const int last_input_dim = normalized_dims[1];
762 
763     // TODO(b/152563685): Consider using Eigen to cover more general cases.
764     const MatrixMap<const float> in_mat(input_data, last_input_dim,
765                                         output_size);
766     VectorMap<float> out(output_data, output_size, 1);
767     out = (in_mat.array().colwise().sum()) / static_cast<float>(last_input_dim);
768     return true;
769   }
770 
771   return MeanGeneral(input_data, input_dims, input_num_dims, output_data,
772                      output_dims, output_num_dims, axis, num_axis_dimensions,
773                      false, normalized_dims, resolved_axis, temp_sum);
774 }
775 
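// Note on the Eigen fast path above: MatrixMap / VectorMap wrap the flat
// buffers as column-major Eigen maps, so the input is viewed as a
// (last_input_dim x output_size) matrix in which each column holds one row of
// the logical 2-D tensor. colwise().sum() therefore yields exactly one sum per
// output element, which is divided by last_input_dim to obtain the mean.
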
776 // Computes the generic value (i.e., sum/max/min/prod) of elements across
777 // dimensions given in axis, dispatching on reduce_type to select the reducer.
778 template <typename T>
779 inline bool ReduceGeneric(const T* input_data, const int* input_dims,
780                           const int input_num_dims, T* output_data,
781                           const int* output_dims, const int output_num_dims,
782                           const int* axis, const int64_t num_axis_dimensions,
783                           int* resolved_axis, int* normalized_dims,
784                           ReduceType reduce_type) {
785   int num_resolved_axis = 0;
786   int normalized_num_dims = 0;
787   if (!reduce_utils::ResolveAxis(input_num_dims, axis, num_axis_dimensions,
788                                  resolved_axis, num_resolved_axis, input_dims,
789                                  normalized_dims, normalized_num_dims)) {
790     return false;
791   }
792   if (num_resolved_axis == 0) {
793     optimized_ops::ReduceIsCopy(input_data, input_dims, input_num_dims,
794                                 output_data);
795     return true;
796   }
797   return ReduceDispatcher(input_data, normalized_dims, normalized_num_dims,
798                           output_dims, output_num_dims, output_data,
799                           resolved_axis, num_resolved_axis, reduce_type);
800 }
801 
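// Example use of ReduceGeneric() (illustrative; scratch buffers are sized by
// the caller, as the reduce kernel does from its temporaries):
//
//   const int input_dims[] = {2, 3};
//   const int output_dims[] = {2};
//   const int axis[] = {1};
//   const int input[] = {4, 1, 7, 2, 9, 5};
//   int output[2];
//   int resolved_axis[2], normalized_dims[2];
//   ReduceGeneric<int>(input, input_dims, /*input_num_dims=*/2, output,
//                      output_dims, /*output_num_dims=*/1, axis,
//                      /*num_axis_dimensions=*/1, resolved_axis,
//                      normalized_dims, ReduceType::kMax);
//   // output == {7, 9}
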
802 }  // namespace optimized_ops
803 }  // namespace tflite
804 
805 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_REDUCE_OPS_H_
806