1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
17 
18 #include <assert.h>
19 #include <stdint.h>
20 #include <sys/types.h>
21 
22 #include <algorithm>
23 #include <cmath>
24 #include <cstdint>
25 #include <limits>
26 #include <memory>
27 #include <tuple>
28 #include <type_traits>
29 
30 #include "tensorflow/lite/kernels/internal/common.h"
31 #include "tensorflow/lite/kernels/internal/compatibility.h"
32 #include "tensorflow/lite/kernels/internal/reference/add.h"
33 #include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
34 
35 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
36 #include <Accelerate/Accelerate.h>
37 #endif
38 
39 #include "Eigen/Core"
40 #include "fixedpoint/fixedpoint.h"
41 #include "ruy/profiler/instrumentation.h"  // from @ruy
42 #include "tensorflow/lite/c/common.h"
43 #include "tensorflow/lite/kernels/cpu_backend_context.h"
44 #include "tensorflow/lite/kernels/cpu_backend_gemm.h"
45 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
46 #include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
47 #include "tensorflow/lite/kernels/internal/cppmath.h"
48 #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
49 #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
50 #include "tensorflow/lite/kernels/internal/quantization_util.h"
51 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
52 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
53 #include "tensorflow/lite/kernels/internal/tensor.h"
54 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
55 #include "tensorflow/lite/kernels/internal/transpose_utils.h"
56 #include "tensorflow/lite/kernels/internal/types.h"
57 #include "unsupported/Eigen/CXX11/Tensor"
58 
59 #if __aarch64__ && __clang__
60 #define TFLITE_SOFTMAX_USE_UINT16_LUT
61 #endif
62 
63 namespace tflite {
64 namespace optimized_ops {
65 
66 // Unoptimized reference ops:
67 using reference_ops::Broadcast4DSlowGreater;
68 using reference_ops::Broadcast4DSlowGreaterEqual;
69 using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
70 using reference_ops::Broadcast4DSlowGreaterWithScaling;
71 using reference_ops::Broadcast4DSlowLess;
72 using reference_ops::Broadcast4DSlowLessEqual;
73 using reference_ops::Broadcast4DSlowLessEqualWithScaling;
74 using reference_ops::Broadcast4DSlowLessWithScaling;
75 using reference_ops::BroadcastAdd4DSlow;
76 using reference_ops::BroadcastMul4DSlow;
77 using reference_ops::BroadcastSub16POTSlow;
78 using reference_ops::BroadcastSubSlow;
79 using reference_ops::Concatenation;
80 using reference_ops::ConcatenationWithScaling;
81 using reference_ops::DepthConcatenation;
82 using reference_ops::Div;
83 using reference_ops::Elu;
84 using reference_ops::FakeQuant;
85 using reference_ops::Fill;
86 using reference_ops::Gather;
87 using reference_ops::Greater;
88 using reference_ops::GreaterEqual;
89 using reference_ops::GreaterEqualWithScaling;
90 using reference_ops::GreaterWithScaling;
91 using reference_ops::LeakyRelu;
92 using reference_ops::Less;
93 using reference_ops::LessEqual;
94 using reference_ops::LessEqualWithScaling;
95 using reference_ops::LessWithScaling;
96 using reference_ops::Mean;
97 using reference_ops::ProcessBroadcastShapes;
98 using reference_ops::RankOneSelect;
99 using reference_ops::Relu1;
100 using reference_ops::Relu6;
101 using reference_ops::ReluX;
102 using reference_ops::Round;
103 using reference_ops::Select;
104 using reference_ops::SpaceToBatchND;
105 using reference_ops::Split;
106 using reference_ops::Sub16;
107 
108 // TODO(b/80247582) Remove this constant.
109 // This will be phased out as the shifts are revised with more thought. Use of a
110 // constant enables us to track progress on this work.
111 //
112 // Used to convert from old-style shifts (right) to new-style (left).
113 static constexpr int kReverseShift = -1;
114 
115 // Make a local VectorMap typedef allowing us to map a float array
116 // as an Eigen vector expression. The std::conditional here is to
117 // construct the suitable Eigen type for the constness of the
118 // data. Indeed, for const data, we need to produce
119 //    Eigen::Map<const Eigen::Matrix<float, ...>>
120 // and not the more straightforward
121 //    Eigen::Map<Eigen::Matrix<const float, ...>>
122 template <typename Scalar>
123 using VectorMap = typename std::conditional<
124     std::is_const<Scalar>::value,
125     Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
126                                    Eigen::Dynamic, 1>>,
127     Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, 1>>>::type;
128 
129 template <typename Scalar>
130 VectorMap<Scalar> MapAsVector(Scalar* data, const RuntimeShape& shape) {
131   const int size = shape.FlatSize();
132   return VectorMap<Scalar>(data, size, 1);
133 }
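// Example (illustrative sketch): MapAsVector yields an Eigen view over a flat
// buffer, so elementwise expressions can be applied in place. The buffer and
// shape below are hypothetical.
//
//   float buffer[6] = {1.f, -2.f, 3.f, -4.f, 5.f, -6.f};
//   RuntimeShape shape({2, 3});
//   auto vec = MapAsVector(buffer, shape);  // 6-element Eigen column vector.
//   vec = vec.cwiseMax(0.0f);               // In-place ReLU over the buffer.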
134 
135 // Make a local MatrixMap typedef allowing us to map a float array
136 // as an Eigen matrix expression. The same explanation as for VectorMap
137 // above also applies here.
138 template <typename Scalar>
139 using MatrixMap = typename std::conditional<
140     std::is_const<Scalar>::value,
141     Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
142                                    Eigen::Dynamic, Eigen::Dynamic>>,
143     Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
144 
145 template <typename Scalar>
146 MatrixMap<Scalar> MapAsMatrixWithLastDimAsRows(Scalar* data,
147                                                const RuntimeShape& shape) {
148   const int dims_count = shape.DimensionsCount();
149   const int rows = shape.Dims(dims_count - 1);
150   const int cols = FlatSizeSkipDim(shape, dims_count - 1);
151   return MatrixMap<Scalar>(data, rows, cols);
152 }
153 
154 template <typename Scalar>
155 MatrixMap<Scalar> MapAsMatrixWithFirstDimAsCols(Scalar* data,
156                                                 const RuntimeShape& shape) {
157   const int cols = shape.Dims(0);
158   const int rows = FlatSizeSkipDim(shape, 0);
159   return MatrixMap<Scalar>(data, rows, cols);
160 }
161 
162 template <typename Scalar>
163 using ArrayMap = typename std::conditional<
164     std::is_const<Scalar>::value,
165     Eigen::Map<const Eigen::Array<typename std::remove_const<Scalar>::type,
166                                   Eigen::Dynamic, Eigen::Dynamic>>,
167     Eigen::Map<Eigen::Array<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
168 
169 template <typename Scalar>
170 ArrayMap<Scalar> MapAsArrayWithLastDimAsRows(Scalar* data,
171                                              const RuntimeShape& shape) {
172   const int dims_count = shape.DimensionsCount();
173   const int rows = shape.Dims(dims_count - 1);
174   const int cols = FlatSizeSkipDim(shape, dims_count - 1);
175   return ArrayMap<Scalar>(data, rows, cols);
176 }
177 
178 // Copied from tensorflow/core/framework/tensor_types.h
179 template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
180 struct TTypes {
181   // Rank-1 tensor (vector) of scalar type T.
182   typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
183                            Eigen::Aligned>
184       Flat;
185   typedef Eigen::TensorMap<
186       Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>>
187       UnalignedConstMatrix;
188 };
189 
190 // TODO(b/62193649): this function is only needed as long
191 // as we have the --variable_batch hack.
192 template <typename Scalar>
193 MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
194                                                    const RuntimeShape& shape,
195                                                    int rows) {
196   const int flatsize = shape.FlatSize();
197   TFLITE_DCHECK_EQ(flatsize % rows, 0);
198   const int cols = flatsize / rows;
199   return MatrixMap<Scalar>(data, rows, cols);
200 }
201 
202 template <typename ElementwiseF, typename ScalarBroadcastF, typename T>
203 inline void BinaryBroadcastFiveFold(const ArithmeticParams& unswitched_params,
204                                     const RuntimeShape& unswitched_input1_shape,
205                                     const T* unswitched_input1_data,
206                                     const RuntimeShape& unswitched_input2_shape,
207                                     const T* unswitched_input2_data,
208                                     const RuntimeShape& output_shape,
209                                     T* output_data, ElementwiseF elementwise_f,
210                                     ScalarBroadcastF scalar_broadcast_f) {
211   ArithmeticParams switched_params = unswitched_params;
212   switched_params.input1_offset = unswitched_params.input2_offset;
213   switched_params.input1_multiplier = unswitched_params.input2_multiplier;
214   switched_params.input1_shift = unswitched_params.input2_shift;
215   switched_params.input2_offset = unswitched_params.input1_offset;
216   switched_params.input2_multiplier = unswitched_params.input1_multiplier;
217   switched_params.input2_shift = unswitched_params.input1_shift;
218 
219   const bool use_unswitched =
220       unswitched_params.broadcast_category ==
221       tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
222 
223   const ArithmeticParams& params =
224       use_unswitched ? unswitched_params : switched_params;
225   const T* input1_data =
226       use_unswitched ? unswitched_input1_data : unswitched_input2_data;
227   const T* input2_data =
228       use_unswitched ? unswitched_input2_data : unswitched_input1_data;
229 
230   // Fivefold nested loops. The second input resets its position for each
231   // iteration of the second loop. The first input resets its position at the
232   // beginning of the fourth loop. The innermost loop applies the elementwise
233   // functor to contiguous sections of the arrays.
234   T* output_data_ptr = output_data;
235   const T* input1_data_ptr = input1_data;
236   const T* input2_data_reset = input2_data;
237   // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
238   // between input shapes. y3 for input 1 is always broadcast, and so the
239   // dimension there is 1, whereas optionally y1 might be broadcast for
240   // input 2. Put another way, input1.shape.FlatSize = y0 * y1 * y2 * y4,
241   // input2.shape.FlatSize = y0 * y2 * y3 * y4.
242   int y0 = params.broadcast_shape[0];
243   int y1 = params.broadcast_shape[1];
244   int y2 = params.broadcast_shape[2];
245   int y3 = params.broadcast_shape[3];
246   int y4 = params.broadcast_shape[4];
247   if (y4 > 1) {
248     // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
249     // dimension.
250     for (int i0 = 0; i0 < y0; ++i0) {
251       const T* input2_data_ptr = nullptr;
252       for (int i1 = 0; i1 < y1; ++i1) {
253         input2_data_ptr = input2_data_reset;
254         for (int i2 = 0; i2 < y2; ++i2) {
255           for (int i3 = 0; i3 < y3; ++i3) {
256             elementwise_f(y4, params, input1_data_ptr, input2_data_ptr,
257                           output_data_ptr);
258             input2_data_ptr += y4;
259             output_data_ptr += y4;
260           }
261           // We have broadcast y4 of input1 data y3 times, and now move on.
262           input1_data_ptr += y4;
263         }
264       }
265       // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
266       input2_data_reset = input2_data_ptr;
267     }
268   } else if (input1_data_ptr != nullptr) {
269     // Special case of y4 == 1, in which the innermost loop is a single
270     // element and can be combined with the next (y3) as an inner broadcast.
271     //
272     // Note that this handles the case of pure scalar broadcast when
273     // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
274     // broadcast with batch (as y2 > 1).
275     //
276     // NOTE The process is the same as the above general case except
277     // simplified for y4 == 1 and the loop over y3 is contained within the
278     // scalar_broadcast_f functor (e.g. AddScalarBroadcast).
279     for (int i0 = 0; i0 < y0; ++i0) {
280       const T* input2_data_ptr = nullptr;
281       for (int i1 = 0; i1 < y1; ++i1) {
282         input2_data_ptr = input2_data_reset;
283         for (int i2 = 0; i2 < y2; ++i2) {
284           scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr,
285                              output_data_ptr);
286           input2_data_ptr += y3;
287           output_data_ptr += y3;
288           input1_data_ptr += 1;
289         }
290       }
291       input2_data_reset = input2_data_ptr;
292     }
293   }
294 }
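// Example (illustrative sketch): the functor signatures BinaryBroadcastFiveFold
// expects, shown with hypothetical float lambdas standing in for the optimized
// kernels (e.g. the elementwise and scalar-broadcast Add) normally passed in.
//
//   auto elementwise_f = [](int size, const ArithmeticParams& params,
//                           const float* in1, const float* in2, float* out) {
//     for (int i = 0; i < size; ++i) out[i] = in1[i] + in2[i];
//   };
//   auto scalar_broadcast_f = [](int size, const ArithmeticParams& params,
//                                float in1, const float* in2, float* out) {
//     for (int i = 0; i < size; ++i) out[i] = in1 + in2[i];
//   };
//   BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
//                           input2_data, output_shape, output_data,
//                           elementwise_f, scalar_broadcast_f);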
295 
296 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
297 
298 // Looks up each element of <indices> in <table>, returns them in a vector.
299 inline uint8x16_t aarch64_lookup_vector(const uint8x16x4_t table[4],
300                                         uint8x16_t indices) {
301   // Look up in 1st quarter of the table: top 2 bits of indices == 00
302   uint8x16_t output1 = vqtbl4q_u8(table[0], indices);
303   // Look up in 2nd quarter of the table: top 2 bits of indices == 01
304   uint8x16_t output2 =
305       vqtbl4q_u8(table[1], veorq_u8(indices, vdupq_n_u8(0x40)));
306   // Look up in 3rd quarter of the table: top 2 bits of indices == 10
307   uint8x16_t output3 =
308       vqtbl4q_u8(table[2], veorq_u8(indices, vdupq_n_u8(0x80)));
309   // Look up in 4th quarter of the table: top 2 bits of indices == 11
310   uint8x16_t output4 =
311       vqtbl4q_u8(table[3], veorq_u8(indices, vdupq_n_u8(0xc0)));
312 
313   // Combine result of the 4 lookups.
314   return vorrq_u8(vorrq_u8(output1, output2), vorrq_u8(output3, output4));
315 }
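// Scalar equivalent (illustrative sketch): treating `table` as one 256-byte
// lookup table laid out as four 64-byte quarters, each of the 16 output lanes
// is
//
//   result[i] = flat_table[indices[i]];
//
// vqtbl4q_u8 returns 0 for out-of-range indices, so each of the four lookups
// only contributes the lanes whose top two index bits select its quarter, and
// OR-ing the partial results combines them.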
316 
317 #endif
318 
319 inline void AddBiasAndEvalActivationFunction(float output_activation_min,
320                                              float output_activation_max,
321                                              const RuntimeShape& bias_shape,
322                                              const float* bias_data,
323                                              const RuntimeShape& array_shape,
324                                              float* array_data) {
325   BiasAndClamp(output_activation_min, output_activation_max,
326                bias_shape.FlatSize(), bias_data, array_shape.FlatSize(),
327                array_data);
328 }
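// Per-element view (illustrative sketch): with bias_size = bias_shape.FlatSize()
// broadcast cyclically over the array, the call above amounts to
//
//   array_data[i] = clamp(array_data[i] + bias_data[i % bias_size],
//                         output_activation_min, output_activation_max)
//
// for every i in [0, array_shape.FlatSize()).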
329 
330 inline void FullyConnected(
331     const FullyConnectedParams& params, const RuntimeShape& input_shape,
332     const float* input_data, const RuntimeShape& weights_shape,
333     const float* weights_data, const RuntimeShape& bias_shape,
334     const float* optional_bias_data, const RuntimeShape& output_shape,
335     float* output_data, CpuBackendContext* cpu_backend_context) {
336   ruy::profiler::ScopeLabel label("FullyConnected");
337   const int dims_count = weights_shape.DimensionsCount();
338   const int input_rows = weights_shape.Dims(dims_count - 1);
339   cpu_backend_gemm::MatrixParams<float> rhs_params;
340   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
341   rhs_params.rows = input_rows;
342   rhs_params.cols = input_shape.FlatSize() / input_rows;
343   rhs_params.cache_policy =
344       cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
345   TFLITE_DCHECK_EQ(input_shape.FlatSize(), rhs_params.rows * rhs_params.cols);
346   cpu_backend_gemm::MatrixParams<float> lhs_params;
347   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
348   lhs_params.cols = weights_shape.Dims(dims_count - 1);
349   lhs_params.rows = FlatSizeSkipDim(weights_shape, dims_count - 1);
350   lhs_params.cache_policy =
351       cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
352   cpu_backend_gemm::MatrixParams<float> dst_params;
353   dst_params.order = cpu_backend_gemm::Order::kColMajor;
354   dst_params.rows = output_shape.Dims(output_shape.DimensionsCount() - 1);
355   dst_params.cols =
356       FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
357   cpu_backend_gemm::GemmParams<float, float> gemm_params;
358   gemm_params.bias = optional_bias_data;
359   gemm_params.clamp_min = params.float_activation_min;
360   gemm_params.clamp_max = params.float_activation_max;
361   cpu_backend_gemm::Gemm(lhs_params, weights_data, rhs_params, input_data,
362                          dst_params, output_data, gemm_params,
363                          cpu_backend_context);
364 }
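// In matrix terms (illustrative sketch): with the weights treated as a
// row-major [output_depth x accum_depth] LHS and the input as a column-major
// [accum_depth x batches] RHS, the Gemm call above computes
//
//   output[b, o] = clamp(sum_k weights[o, k] * input[b, k] + bias[o],
//                        float_activation_min, float_activation_max)
//
// for every batch b and output channel o (the bias term is skipped when
// optional_bias_data is null).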
365 
366 inline void FullyConnected(
367     const FullyConnectedParams& params, const RuntimeShape& input_shape,
368     const uint8* input_data, const RuntimeShape& filter_shape,
369     const uint8* filter_data, const RuntimeShape& bias_shape,
370     const int32* bias_data, const RuntimeShape& output_shape,
371     uint8* output_data, CpuBackendContext* cpu_backend_context) {
372   ruy::profiler::ScopeLabel label("FullyConnected/8bit");
373   const int32 input_offset = params.input_offset;
374   const int32 filter_offset = params.weights_offset;
375   const int32 output_offset = params.output_offset;
376   const int32 output_multiplier = params.output_multiplier;
377   const int output_shift = params.output_shift;
378   const int32 output_activation_min = params.quantized_activation_min;
379   const int32 output_activation_max = params.quantized_activation_max;
380   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
381   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
382   // TODO(b/62193649): This really should be:
383   //     const int batches = ArraySize(output_dims, 1);
384   // but the current --variable_batch hack consists in overwriting the 3rd
385   // dimension with the runtime batch size, as we don't keep track, for each
386   // array, of which dimension is its batch dimension.
387   const int output_dim_count = output_shape.DimensionsCount();
388   const int filter_dim_count = filter_shape.DimensionsCount();
389   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
390   const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
391   const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
392   TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
393   const int output_rows = output_shape.Dims(output_dim_count - 1);
394   TFLITE_DCHECK_EQ(output_rows, filter_rows);
395   if (bias_data) {
396     TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
397   }
398 
399   cpu_backend_gemm::MatrixParams<uint8> lhs_params;
400   lhs_params.rows = filter_rows;
401   lhs_params.cols = filter_cols;
402   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
403   lhs_params.zero_point = -filter_offset;
404   lhs_params.cache_policy =
405       cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
406   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
407   rhs_params.rows = filter_cols;
408   rhs_params.cols = batches;
409   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
410   rhs_params.zero_point = -input_offset;
411   rhs_params.cache_policy =
412       cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
413   cpu_backend_gemm::MatrixParams<uint8> dst_params;
414   dst_params.rows = filter_rows;
415   dst_params.cols = batches;
416   dst_params.order = cpu_backend_gemm::Order::kColMajor;
417   dst_params.zero_point = output_offset;
418   cpu_backend_gemm::GemmParams<int32, uint8> gemm_params;
419   gemm_params.bias = bias_data;
420   gemm_params.clamp_min = output_activation_min;
421   gemm_params.clamp_max = output_activation_max;
422   gemm_params.multiplier_fixedpoint = output_multiplier;
423   gemm_params.multiplier_exponent = output_shift;
424   cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
425                          dst_params, output_data, gemm_params,
426                          cpu_backend_context);
427 }
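// Quantized arithmetic (illustrative sketch): params.weights_offset and
// params.input_offset hold the negated zero points, so negating them again
// above recovers the zero points cpu_backend_gemm expects. Per output element
// the Gemm call effectively computes
//
//   acc[b, o]    = sum_k (filter[o, k] - filter_zero_point) *
//                        (input[b, k] - input_zero_point) + bias[o]
//   output[b, o] = clamp(output_offset +
//                            MultiplyByQuantizedMultiplier(
//                                acc[b, o], output_multiplier, output_shift),
//                        output_activation_min, output_activation_max)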
428 
429 inline void FullyConnected(
430     const FullyConnectedParams& params, const RuntimeShape& input_shape,
431     const uint8* input_data, const RuntimeShape& filter_shape,
432     const uint8* filter_data, const RuntimeShape& bias_shape,
433     const int32* bias_data_int32, const RuntimeShape& output_shape,
434     int16* output_data, CpuBackendContext* cpu_backend_context) {
435   ruy::profiler::ScopeLabel label("FullyConnected/Uint8Int16");
436   const int32 input_offset = params.input_offset;
437   const int32 filter_offset = params.weights_offset;
438   const int32 output_offset = params.output_offset;
439   const int32 output_multiplier = params.output_multiplier;
440   const int output_shift = params.output_shift;
441   const int32 output_activation_min = params.quantized_activation_min;
442   const int32 output_activation_max = params.quantized_activation_max;
443   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
444   TFLITE_DCHECK_EQ(output_offset, 0);
445   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
446   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
447 
448   // TODO(b/62193649): This really should be:
449   //     const int batches = ArraySize(output_dims, 1);
450   // but the current --variable_batch hack consists in overwriting the 3rd
451   // dimension with the runtime batch size, as we don't keep track, for each
452   // array, of which dimension is its batch dimension.
453   const int output_dim_count = output_shape.DimensionsCount();
454   const int filter_dim_count = filter_shape.DimensionsCount();
455   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
456   const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
457                                        output_shape, output_dim_count - 1);
458   const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
459 
460   cpu_backend_gemm::MatrixParams<uint8> lhs_params;
461   lhs_params.rows = output_depth;
462   lhs_params.cols = accum_depth;
463   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
464   lhs_params.zero_point = -filter_offset;
465   lhs_params.cache_policy =
466       cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
467   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
468   rhs_params.rows = accum_depth;
469   rhs_params.cols = batches;
470   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
471   rhs_params.zero_point = -input_offset;
472   rhs_params.cache_policy =
473       cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
474   cpu_backend_gemm::MatrixParams<int16> dst_params;
475   dst_params.rows = output_depth;
476   dst_params.cols = batches;
477   dst_params.order = cpu_backend_gemm::Order::kColMajor;
478   dst_params.zero_point = 0;
479   cpu_backend_gemm::GemmParams<int32, int16> gemm_params;
480   gemm_params.bias = bias_data_int32;
481   gemm_params.clamp_min = output_activation_min;
482   gemm_params.clamp_max = output_activation_max;
483   gemm_params.multiplier_fixedpoint = output_multiplier;
484   gemm_params.multiplier_exponent = output_shift;
485   cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
486                          dst_params, output_data, gemm_params,
487                          cpu_backend_context);
488 }
489 
490 // Internal function doing the actual arithmetic work for
491 // ShuffledFullyConnected.
492 // May be called either directly by it (single-threaded case) or may be used
493 // as the 'task' for worker threads to run (multi-threaded case, see
494 // ShuffledFullyConnectedWorkerTask below).
495 inline void ShuffledFullyConnectedWorkerImpl(
496     const uint8* shuffled_input_workspace_data,
497     const int8* shuffled_weights_data, int batches, int output_depth,
498     int output_stride, int accum_depth, const int32* bias_data,
499     int32 output_multiplier, int output_shift, int16* output_data) {
500 #if defined USE_NEON
501   const int8* shuffled_weights_ptr = shuffled_weights_data;
502   if (batches == 1) {
503     const int right_shift = output_shift > 0 ? 0 : -output_shift;
504     const int left_shift = output_shift > 0 ? output_shift : 0;
505     for (int c = 0; c < output_depth; c += 4) {
506       // Accumulation loop.
507       int32x4_t row_accum0 = vdupq_n_s32(0);
508       int32x4_t row_accum1 = vdupq_n_s32(0);
509       int32x4_t row_accum2 = vdupq_n_s32(0);
510       int32x4_t row_accum3 = vdupq_n_s32(0);
511       for (int d = 0; d < accum_depth; d += 16) {
512         int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
513         int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
514         int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
515         int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
516         shuffled_weights_ptr += 64;
517         int8x16_t input =
518             vreinterpretq_s8_u8(vld1q_u8(shuffled_input_workspace_data + d));
519         int16x8_t local_accum0 =
520             vmull_s8(vget_low_s8(weights0), vget_low_s8(input));
521         int16x8_t local_accum1 =
522             vmull_s8(vget_low_s8(weights1), vget_low_s8(input));
523         int16x8_t local_accum2 =
524             vmull_s8(vget_low_s8(weights2), vget_low_s8(input));
525         int16x8_t local_accum3 =
526             vmull_s8(vget_low_s8(weights3), vget_low_s8(input));
527         local_accum0 =
528             vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input));
529         local_accum1 =
530             vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input));
531         local_accum2 =
532             vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input));
533         local_accum3 =
534             vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input));
535         row_accum0 = vpadalq_s16(row_accum0, local_accum0);
536         row_accum1 = vpadalq_s16(row_accum1, local_accum1);
537         row_accum2 = vpadalq_s16(row_accum2, local_accum2);
538         row_accum3 = vpadalq_s16(row_accum3, local_accum3);
539       }
540       // Horizontally reduce accumulators
541       int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
542           pairwise_reduced_acc_2, pairwise_reduced_acc_3;
543       pairwise_reduced_acc_0 =
544           vpadd_s32(vget_low_s32(row_accum0), vget_high_s32(row_accum0));
545       pairwise_reduced_acc_1 =
546           vpadd_s32(vget_low_s32(row_accum1), vget_high_s32(row_accum1));
547       pairwise_reduced_acc_2 =
548           vpadd_s32(vget_low_s32(row_accum2), vget_high_s32(row_accum2));
549       pairwise_reduced_acc_3 =
550           vpadd_s32(vget_low_s32(row_accum3), vget_high_s32(row_accum3));
551       const int32x2_t reduced_lo =
552           vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
553       const int32x2_t reduced_hi =
554           vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
555       int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
556       // Add bias values.
557       int32x4_t bias_vec = vld1q_s32(bias_data + c);
558       reduced = vaddq_s32(reduced, bias_vec);
559       reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
560       // Multiply by the fixed-point multiplier.
561       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
562       // Rounding-shift-right.
563       using gemmlowp::RoundingDivideByPOT;
564       reduced = RoundingDivideByPOT(reduced, right_shift);
565       // Narrow values down to 16 bit signed.
566       const int16x4_t res16 = vqmovn_s32(reduced);
567       vst1_s16(output_data + c, res16);
568     }
569   } else if (batches == 4) {
570     const int right_shift = output_shift > 0 ? 0 : -output_shift;
571     const int left_shift = output_shift > 0 ? output_shift : 0;
572     for (int c = 0; c < output_depth; c += 4) {
573       const int8* shuffled_input_ptr =
574           reinterpret_cast<const int8*>(shuffled_input_workspace_data);
575       // Accumulation loop.
576       int32x4_t row_accum00 = vdupq_n_s32(0);
577       int32x4_t row_accum10 = vdupq_n_s32(0);
578       int32x4_t row_accum20 = vdupq_n_s32(0);
579       int32x4_t row_accum30 = vdupq_n_s32(0);
580       int32x4_t row_accum01 = vdupq_n_s32(0);
581       int32x4_t row_accum11 = vdupq_n_s32(0);
582       int32x4_t row_accum21 = vdupq_n_s32(0);
583       int32x4_t row_accum31 = vdupq_n_s32(0);
584       int32x4_t row_accum02 = vdupq_n_s32(0);
585       int32x4_t row_accum12 = vdupq_n_s32(0);
586       int32x4_t row_accum22 = vdupq_n_s32(0);
587       int32x4_t row_accum32 = vdupq_n_s32(0);
588       int32x4_t row_accum03 = vdupq_n_s32(0);
589       int32x4_t row_accum13 = vdupq_n_s32(0);
590       int32x4_t row_accum23 = vdupq_n_s32(0);
591       int32x4_t row_accum33 = vdupq_n_s32(0);
592       for (int d = 0; d < accum_depth; d += 16) {
593         int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
594         int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
595         int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
596         int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
597         shuffled_weights_ptr += 64;
598         int8x16_t input0 = vld1q_s8(shuffled_input_ptr + 0);
599         int8x16_t input1 = vld1q_s8(shuffled_input_ptr + 16);
600         int8x16_t input2 = vld1q_s8(shuffled_input_ptr + 32);
601         int8x16_t input3 = vld1q_s8(shuffled_input_ptr + 48);
602         shuffled_input_ptr += 64;
603         int16x8_t local_accum0, local_accum1, local_accum2, local_accum3;
604 #define TFLITE_SHUFFLED_FC_ACCUM(B)                                           \
605   local_accum0 = vmull_s8(vget_low_s8(weights0), vget_low_s8(input##B));      \
606   local_accum1 = vmull_s8(vget_low_s8(weights1), vget_low_s8(input##B));      \
607   local_accum2 = vmull_s8(vget_low_s8(weights2), vget_low_s8(input##B));      \
608   local_accum3 = vmull_s8(vget_low_s8(weights3), vget_low_s8(input##B));      \
609   local_accum0 =                                                              \
610       vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input##B)); \
611   local_accum1 =                                                              \
612       vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input##B)); \
613   local_accum2 =                                                              \
614       vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input##B)); \
615   local_accum3 =                                                              \
616       vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input##B)); \
617   row_accum0##B = vpadalq_s16(row_accum0##B, local_accum0);                   \
618   row_accum1##B = vpadalq_s16(row_accum1##B, local_accum1);                   \
619   row_accum2##B = vpadalq_s16(row_accum2##B, local_accum2);                   \
620   row_accum3##B = vpadalq_s16(row_accum3##B, local_accum3);
621 
622         TFLITE_SHUFFLED_FC_ACCUM(0)
623         TFLITE_SHUFFLED_FC_ACCUM(1)
624         TFLITE_SHUFFLED_FC_ACCUM(2)
625         TFLITE_SHUFFLED_FC_ACCUM(3)
626 
627 #undef TFLITE_SHUFFLED_FC_ACCUM
628       }
629       // Horizontally reduce accumulators
630 
631 #define TFLITE_SHUFFLED_FC_STORE(B)                                           \
632   {                                                                           \
633     int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,                 \
634         pairwise_reduced_acc_2, pairwise_reduced_acc_3;                       \
635     pairwise_reduced_acc_0 =                                                  \
636         vpadd_s32(vget_low_s32(row_accum0##B), vget_high_s32(row_accum0##B)); \
637     pairwise_reduced_acc_1 =                                                  \
638         vpadd_s32(vget_low_s32(row_accum1##B), vget_high_s32(row_accum1##B)); \
639     pairwise_reduced_acc_2 =                                                  \
640         vpadd_s32(vget_low_s32(row_accum2##B), vget_high_s32(row_accum2##B)); \
641     pairwise_reduced_acc_3 =                                                  \
642         vpadd_s32(vget_low_s32(row_accum3##B), vget_high_s32(row_accum3##B)); \
643     const int32x2_t reduced_lo =                                              \
644         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);            \
645     const int32x2_t reduced_hi =                                              \
646         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);            \
647     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);                 \
648     int32x4_t bias_vec = vld1q_s32(bias_data + c);                            \
649     reduced = vaddq_s32(reduced, bias_vec);                                   \
650     reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));                    \
651     reduced = vqrdmulhq_n_s32(reduced, output_multiplier);                    \
652     using gemmlowp::RoundingDivideByPOT;                                      \
653     reduced = RoundingDivideByPOT(reduced, right_shift);                      \
654     const int16x4_t res16 = vqmovn_s32(reduced);                              \
655     vst1_s16(output_data + c + B * output_stride, res16);                     \
656   }
657 
658       TFLITE_SHUFFLED_FC_STORE(0);
659       TFLITE_SHUFFLED_FC_STORE(1);
660       TFLITE_SHUFFLED_FC_STORE(2);
661       TFLITE_SHUFFLED_FC_STORE(3);
662 
663 #undef TFLITE_SHUFFLED_FC_STORE
664     }
665   } else {
666     TFLITE_DCHECK(false);
667     return;
668   }
669 #else
670   if (batches == 1) {
671     int16* output_ptr = output_data;
672     // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
673     // so that just reinterpreting them as int8 values is equivalent to
674     // subtracting 128 from them, thus implementing for free the subtraction of
675     // the zero_point value 128.
676     const int8* shuffled_weights_ptr =
677         reinterpret_cast<const int8*>(shuffled_weights_data);
678     // Likewise, we preshuffled and pre-xored the input data above.
679     const int8* shuffled_input_data =
680         reinterpret_cast<const int8*>(shuffled_input_workspace_data);
681     for (int c = 0; c < output_depth; c += 4) {
682       // Internal accumulation.
683       // Accumulators start at zero; the bias is added after the loop.
684       int32 accum[4] = {0};
685       // Accumulation loop.
686       for (int d = 0; d < accum_depth; d += 16) {
687         for (int i = 0; i < 4; i++) {
688           for (int j = 0; j < 16; j++) {
689             int8 input_val = shuffled_input_data[d + j];
690             int8 weights_val = *shuffled_weights_ptr++;
691             accum[i] += weights_val * input_val;
692           }
693         }
694       }
695       for (int i = 0; i < 4; i++) {
696         // Add bias value
697         int acc = accum[i] + bias_data[c + i];
698         // Down-scale the final int32 accumulator to the scale used by our
699         // (16-bit, typically 3 integer bits) fixed-point format. The quantized
700         // multiplier and shift here have been pre-computed offline
701         // (e.g. by toco).
702         acc =
703             MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
704         // Saturate, cast to int16, and store to output array.
705         acc = std::max(acc, -32768);
706         acc = std::min(acc, 32767);
707         output_ptr[c + i] = acc;
708       }
709     }
710   } else if (batches == 4) {
711     int16* output_ptr = output_data;
712     // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
713     // so that just reinterpreting them as int8 values is equivalent to
714     // subtracting 128 from them, thus implementing for free the subtraction of
715     // the zero_point value 128.
716     const int8* shuffled_weights_ptr =
717         reinterpret_cast<const int8*>(shuffled_weights_data);
718     // Likewise, we preshuffled and pre-xored the input data above.
719     const int8* shuffled_input_data =
720         reinterpret_cast<const int8*>(shuffled_input_workspace_data);
721     for (int c = 0; c < output_depth; c += 4) {
722       const int8* shuffled_input_ptr = shuffled_input_data;
723       // Accumulation loop.
724       // Internal accumulation.
725       // Accumulators start at zero; the bias is added after the loop.
726       int32 accum[4][4];
727       for (int i = 0; i < 4; i++) {
728         for (int b = 0; b < 4; b++) {
729           accum[i][b] = 0;
730         }
731       }
732       for (int d = 0; d < accum_depth; d += 16) {
733         for (int i = 0; i < 4; i++) {
734           for (int b = 0; b < 4; b++) {
735             for (int j = 0; j < 16; j++) {
736               int8 input_val = shuffled_input_ptr[16 * b + j];
737               int8 weights_val = shuffled_weights_ptr[16 * i + j];
738               accum[i][b] += weights_val * input_val;
739             }
740           }
741         }
742         shuffled_input_ptr += 64;
743         shuffled_weights_ptr += 64;
744       }
745       for (int i = 0; i < 4; i++) {
746         for (int b = 0; b < 4; b++) {
747           // Add bias value
748           int acc = accum[i][b] + bias_data[c + i];
749           // Down-scale the final int32 accumulator to the scale used by our
750           // (16-bit, typically 3 integer bits) fixed-point format. The
751           // quantized multiplier and shift here have been pre-computed offline
752           // (e.g. by toco).
753           acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
754                                               output_shift);
755           // Saturate, cast to int16, and store to output array.
756           acc = std::max(acc, -32768);
757           acc = std::min(acc, 32767);
758           output_ptr[b * output_stride + c + i] = acc;
759         }
760       }
761     }
762   } else {
763     TFLITE_DCHECK(false);
764     return;
765   }
766 #endif
767 }
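// Worked example of the sign-bit trick used in the scalar path above
// (illustrative): the uint8 value 200 (0xC8) XOR 0x80 is 0x48 = 72, and
// reinterpreting 0x48 as int8 also gives 72 = 200 - 128. Flipping the top bit
// and reading the byte as int8 therefore applies the zero-point subtraction of
// 128 at no runtime cost.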
768 
769 // Wraps ShuffledFullyConnectedWorkerImpl into a Task class
770 // to allow using gemmlowp's threadpool.
771 struct ShuffledFullyConnectedWorkerTask : cpu_backend_threadpool::Task {
772   ShuffledFullyConnectedWorkerTask(const uint8* input_data,
773                                    const int8* shuffled_weights_data,
774                                    int batches, int output_depth,
775                                    int output_stride, int accum_depth,
776                                    const int32* bias_data,
777                                    int32 output_multiplier, int output_shift,
778                                    int16* output_data)
779       : input_data_(input_data),
780         shuffled_weights_data_(shuffled_weights_data),
781         batches_(batches),
782         output_depth_(output_depth),
783         output_stride_(output_stride),
784         accum_depth_(accum_depth),
785         bias_data_(bias_data),
786         output_multiplier_(output_multiplier),
787         output_shift_(output_shift),
788         output_data_(output_data) {}
789 
790   void Run() override {
791     ShuffledFullyConnectedWorkerImpl(
792         input_data_, shuffled_weights_data_, batches_, output_depth_,
793         output_stride_, accum_depth_, bias_data_, output_multiplier_,
794         output_shift_, output_data_);
795   }
796 
797   const uint8* input_data_;
798   const int8* shuffled_weights_data_;
799   int batches_;
800   int output_depth_;
801   int output_stride_;
802   int accum_depth_;
803   const int32* bias_data_;
804   int32 output_multiplier_;
805   int output_shift_;
806   int16* output_data_;
807 };
808 
809 inline void ShuffledFullyConnected(
810     const FullyConnectedParams& params, const RuntimeShape& input_shape,
811     const uint8* input_data, const RuntimeShape& weights_shape,
812     const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
813     const int32* bias_data, const RuntimeShape& output_shape,
814     int16* output_data, uint8* shuffled_input_workspace_data,
815     CpuBackendContext* cpu_backend_context) {
816   ruy::profiler::ScopeLabel label("ShuffledFullyConnected/8bit");
817   const int32 output_multiplier = params.output_multiplier;
818   const int output_shift = params.output_shift;
819   const int32 output_activation_min = params.quantized_activation_min;
820   const int32 output_activation_max = params.quantized_activation_max;
821   TFLITE_DCHECK_EQ(output_activation_min, -32768);
822   TFLITE_DCHECK_EQ(output_activation_max, 32767);
823   TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
824   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
825   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
826   // TODO(b/62193649): This really should be:
827   //     const int batches = ArraySize(output_dims, 1);
828   // but the current --variable_batch hack consists in overwriting the 3rd
829   // dimension with the runtime batch size, as we don't keep track, for each
830   // array, of which dimension is its batch dimension.
831   const int output_dim_count = output_shape.DimensionsCount();
832   const int weights_dim_count = weights_shape.DimensionsCount();
833   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
834   const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
835                                        output_shape, output_dim_count - 1);
836   const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
837   TFLITE_DCHECK((accum_depth % 16) == 0);
838   TFLITE_DCHECK((output_depth % 4) == 0);
839   // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
840   // so that just reinterpreting them as int8 values is equivalent to
841   // subtracting 128 from them, thus implementing for free the subtraction of
842   // the zero_point value 128.
843   const int8* int8_shuffled_weights_data =
844       reinterpret_cast<const int8*>(shuffled_weights_data);
845 
846   // Shuffling and xoring of input activations into the workspace buffer
847   if (batches == 1) {
848 #ifdef USE_NEON
849     const uint8x16_t signbit = vdupq_n_u8(0x80);
850     for (int i = 0; i < accum_depth; i += 16) {
851       uint8x16_t val = vld1q_u8(input_data + i);
852       val = veorq_u8(val, signbit);
853       vst1q_u8(shuffled_input_workspace_data + i, val);
854     }
855 #else
856     for (int i = 0; i < accum_depth; i++) {
857       shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
858     }
859 #endif
860   } else if (batches == 4) {
861     uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
862     int c = 0;
863 #ifdef USE_NEON
864     const uint8x16_t signbit = vdupq_n_u8(0x80);
865     for (c = 0; c < accum_depth; c += 16) {
866       const uint8* src_data_ptr = input_data + c;
867       uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth);
868       uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth);
869       uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth);
870       uint8x16_t val3 = vld1q_u8(src_data_ptr + 3 * accum_depth);
871       val0 = veorq_u8(val0, signbit);
872       val1 = veorq_u8(val1, signbit);
873       val2 = veorq_u8(val2, signbit);
874       val3 = veorq_u8(val3, signbit);
875       vst1q_u8(shuffled_input_workspace_ptr + 0, val0);
876       vst1q_u8(shuffled_input_workspace_ptr + 16, val1);
877       vst1q_u8(shuffled_input_workspace_ptr + 32, val2);
878       vst1q_u8(shuffled_input_workspace_ptr + 48, val3);
879       shuffled_input_workspace_ptr += 64;
880     }
881 #else
882     for (c = 0; c < accum_depth; c += 16) {
883       for (int b = 0; b < 4; b++) {
884         const uint8* src_data_ptr = input_data + b * accum_depth + c;
885         for (int j = 0; j < 16; j++) {
886           uint8 src_val = *src_data_ptr++;
887           // Flip the sign bit, so that the kernel will only need to
888           // reinterpret these uint8 values as int8, getting for free the
889           // subtraction of the zero_point value 128.
890           uint8 dst_val = src_val ^ 0x80;
891           *shuffled_input_workspace_ptr++ = dst_val;
892         }
893       }
894     }
895 #endif
896   } else {
897     TFLITE_DCHECK(false);
898     return;
899   }
900 
901   static constexpr int kKernelRows = 4;
902   const int thread_count =
903       LegacyHowManyThreads<kKernelRows>(cpu_backend_context->max_num_threads(),
904                                         output_depth, batches, accum_depth);
905   if (thread_count == 1) {
906     // Single-thread case: do the computation on the current thread, don't
907     // use a threadpool
908     ShuffledFullyConnectedWorkerImpl(
909         shuffled_input_workspace_data, int8_shuffled_weights_data, batches,
910         output_depth, output_depth, accum_depth, bias_data, output_multiplier,
911         output_shift, output_data);
912     return;
913   }
914 
915   // Multi-threaded case: use the gemmlowp context's threadpool.
916   TFLITE_DCHECK_GT(thread_count, 1);
917   std::vector<ShuffledFullyConnectedWorkerTask> tasks;
918   // TODO(b/131746020) don't create new heap allocations every time.
919   // At least we make it a single heap allocation by using reserve().
920   tasks.reserve(thread_count);
921   const int kRowsPerWorker =
922       RoundUp<kKernelRows>(CeilQuotient(output_depth, thread_count));
923   int row_start = 0;
924   for (int i = 0; i < thread_count; i++) {
925     int row_end = std::min(output_depth, row_start + kRowsPerWorker);
926     tasks.emplace_back(shuffled_input_workspace_data,
927                        int8_shuffled_weights_data + row_start * accum_depth,
928                        batches, row_end - row_start, output_depth, accum_depth,
929                        bias_data + row_start, output_multiplier, output_shift,
930                        output_data + row_start);
931     row_start = row_end;
932   }
933   TFLITE_DCHECK_EQ(row_start, output_depth);
934   cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
935                                   cpu_backend_context);
936 }
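// Partitioning example (illustrative): with output_depth = 100 and
// thread_count = 3, CeilQuotient(100, 3) = 34 and RoundUp<4>(34) = 36, so the
// three tasks cover output rows [0, 36), [36, 72) and [72, 100), each starting
// at a multiple of kKernelRows as required by the 4-row kernel.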
937 
938 #ifdef USE_NEON
939 
940 inline int32x4_t RoundToNearest(const float32x4_t input) {
941 #if defined(__aarch64__) || defined(__SSSE3__)
942   // Note: vcvtnq_s32_f32 is not available in ARMv7
943   return vcvtnq_s32_f32(input);
944 #else
945   static const float32x4_t zero_val_dup = vdupq_n_f32(0.0f);
946   static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
947   static const float32x4_t minus_point5_val_dup = vdupq_n_f32(-0.5f);
948 
949   const uint32x4_t mask = vcltq_f32(input, zero_val_dup);
950   const float32x4_t round =
951       vbslq_f32(mask, minus_point5_val_dup, point5_val_dup);
952   return vcvtq_s32_f32(vaddq_f32(input, round));
953 #endif  // defined(__aarch64__) || defined(__SSSE3__)
954 }
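// Behavioral note (stated as an observation): vcvtnq_s32_f32 rounds ties to
// even (2.5 -> 2), while the fallback above adds +/-0.5 and truncates, which
// rounds ties away from zero (2.5 -> 3). The two paths can therefore differ by
// 1 on exact .5 inputs; for other values both return the nearest integer.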
955 
956 inline uint32x4_t RoundToNearestUnsigned(const float32x4_t input) {
957 #if defined(__aarch64__)
958   // Note that vcvtnq_u32_f32 is not available in ARMv7 or in arm_neon_sse.h.
959   return vcvtnq_u32_f32(input);
960 #else
961   static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
962 
963   return vcvtq_u32_f32(vaddq_f32(input, point5_val_dup));
964 #endif  // defined(__aarch64__)
965 }
966 
967 #endif  // USE_NEON
968 
969 inline void MeanImpl(const tflite::MeanParams& op_params,
970                      const RuntimeShape& input_shape, const uint8_t* input_data,
971                      int32 multiplier, int32 shift, int32 bias,
972                      const RuntimeShape& output_shape, uint8_t* output_data,
973                      int start_depth, int end_depth) {
974   ruy::profiler::ScopeLabel label("Mean4D/Uint8/MeanImpl");
975 
976   // The current implementation only supports rank-4 tensors and simultaneous
977   // reduction over width and height.
978   const int output_batch = output_shape.Dims(0);
979   const int output_height = output_shape.Dims(1);
980   const int output_width = output_shape.Dims(2);
981   const int input_height = input_shape.Dims(1);
982   const int input_width = input_shape.Dims(2);
983 
984   TFLITE_CHECK_EQ(op_params.axis_count, 2);
985   TFLITE_CHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
986                (op_params.axis[0] == 2 && op_params.axis[1] == 1));
987   TFLITE_CHECK_EQ(output_height, 1);
988   TFLITE_CHECK_EQ(output_width, 1);
989 
990   constexpr int32_t kMinValue = std::numeric_limits<uint8_t>::min();
991   constexpr int32_t kMaxValue = std::numeric_limits<uint8_t>::max();
992 
993 #ifdef USE_NEON
994   const int32x4_t bias_dup = vdupq_n_s32(bias);
995   const int32x4_t min_dup = vdupq_n_s32(kMinValue);
996   const int32x4_t max_dup = vdupq_n_s32(kMaxValue);
997 #endif  // USE_NEON
998 
999   for (int out_b = 0; out_b < output_batch; ++out_b) {
1000     int out_d = start_depth;
1001 #ifdef USE_NEON
1002 
1003     for (; out_d <= end_depth - 16; out_d += 16) {
1004       int32x4x4_t temp_sum;
1005       temp_sum.val[0] = vdupq_n_s32(0);
1006       temp_sum.val[1] = vdupq_n_s32(0);
1007       temp_sum.val[2] = vdupq_n_s32(0);
1008       temp_sum.val[3] = vdupq_n_s32(0);
1009       for (int in_h = 0; in_h < input_height; ++in_h) {
1010         for (int in_w = 0; in_w < input_width; ++in_w) {
1011           const uint8_t* input_data_ptr =
1012               input_data + Offset(input_shape, out_b, in_h, in_w, out_d);
1013           uint8x16_t input_data_val = vld1q_u8(input_data_ptr);
1014 
1015           int16x8_t input_data_low_shift =
1016               vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_data_val)));
1017           int16x8_t input_data_high_shift =
1018               vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_data_val)));
1019 
1020           int32x4_t input_low_low =
1021               vmovl_s16(vget_low_s16(input_data_low_shift));
1022           int32x4_t input_high_low =
1023               vmovl_s16(vget_high_s16(input_data_low_shift));
1024           int32x4_t input_low_high =
1025               vmovl_s16(vget_low_s16(input_data_high_shift));
1026           int32x4_t input_high_high =
1027               vmovl_s16(vget_high_s16(input_data_high_shift));
1028 
1029           temp_sum.val[0] = vaddq_s32(temp_sum.val[0], input_low_low);
1030           temp_sum.val[1] = vaddq_s32(temp_sum.val[1], input_high_low);
1031           temp_sum.val[2] = vaddq_s32(temp_sum.val[2], input_low_high);
1032           temp_sum.val[3] = vaddq_s32(temp_sum.val[3], input_high_high);
1033         }
1034       }
1035 
1036       temp_sum =
1037           MultiplyByQuantizedMultiplier4Rows(temp_sum, multiplier, shift);
1038 
1039       temp_sum.val[0] = vaddq_s32(temp_sum.val[0], bias_dup);
1040       temp_sum.val[1] = vaddq_s32(temp_sum.val[1], bias_dup);
1041       temp_sum.val[2] = vaddq_s32(temp_sum.val[2], bias_dup);
1042       temp_sum.val[3] = vaddq_s32(temp_sum.val[3], bias_dup);
1043 
1044       temp_sum.val[0] = vminq_s32(vmaxq_s32(temp_sum.val[0], min_dup), max_dup);
1045       temp_sum.val[1] = vminq_s32(vmaxq_s32(temp_sum.val[1], min_dup), max_dup);
1046       temp_sum.val[2] = vminq_s32(vmaxq_s32(temp_sum.val[2], min_dup), max_dup);
1047       temp_sum.val[3] = vminq_s32(vmaxq_s32(temp_sum.val[3], min_dup), max_dup);
1048 
1049       uint16x4_t narrowed_low_low =
1050           vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[0]));
1051       uint16x4_t narrowed_high_low =
1052           vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[1]));
1053       uint16x4_t narrowed_low_high =
1054           vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[2]));
1055       uint16x4_t narrowed_high_high =
1056           vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[3]));
1057 
1058       uint16x8_t combined_low =
1059           vcombine_u16(narrowed_low_low, narrowed_high_low);
1060       uint16x8_t combined_high =
1061           vcombine_u16(narrowed_low_high, narrowed_high_high);
1062 
1063       uint8x8_t narrowed_low = vmovn_u16(combined_low);
1064       uint8x8_t narrowed_high = vmovn_u16(combined_high);
1065 
1066       uint8x16_t combined_output = vcombine_u8(narrowed_low, narrowed_high);
1067 
1068       uint8_t* output_data_ptr =
1069           output_data + Offset(output_shape, out_b, 0, 0, out_d);
1070       vst1q_u8(output_data_ptr, combined_output);
1071     }
1072 #endif  // USE_NEON
1073 
1074     for (; out_d < end_depth; ++out_d) {
1075       int acc = 0;
1076       for (int in_h = 0; in_h < input_height; ++in_h) {
1077         for (int in_w = 0; in_w < input_width; ++in_w) {
1078           acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
1079         }
1080       }
1081 
1082       acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
1083       acc += bias;
1084       acc = std::min(std::max(acc, kMinValue), kMaxValue);
1085       output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
1086           static_cast<uint8_t>(acc);
1087     }
1088   }
1089 }
1090 
1091 struct MeanWorkerTask : cpu_backend_threadpool::Task {
1092   MeanWorkerTask(const tflite::MeanParams& op_params,
1093                  const RuntimeShape& input_shape, const uint8_t* input_data,
1094                  int32 multiplier, int32 shift, int32 bias,
1095                  const RuntimeShape& output_shape, uint8_t* output_data,
1096                  int start_height, int end_height)
1097       : op_params(op_params),
1098         input_shape(input_shape),
1099         input_data(input_data),
1100         multiplier(multiplier),
1101         shift(shift),
1102         bias(bias),
1103         output_shape(output_shape),
1104         output_data(output_data),
1105         start_height(start_height),
1106         end_height(end_height) {}
1107 
1108   void Run() override {
1109     MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias,
1110              output_shape, output_data, start_height, end_height);
1111   }
1112 
1113  private:
1114   const tflite::MeanParams& op_params;
1115   const RuntimeShape& input_shape;
1116   const uint8_t* input_data;
1117   int32 multiplier;
1118   int32 shift;
1119   int32 bias;
1120   const RuntimeShape& output_shape;
1121   uint8_t* output_data;
1122   int start_height;
1123   int end_height;
1124 };
1125 
1126 inline void Mean(const tflite::MeanParams& op_params,
1127                  const RuntimeShape& unextended_input_shape,
1128                  const uint8_t* input_data, int32 input_zero_point,
1129                  float input_scale, const RuntimeShape& unextended_output_shape,
1130                  uint8_t* output_data, int32 output_zero_point,
1131                  float output_scale, CpuBackendContext* cpu_backend_context) {
1132   ruy::profiler::ScopeLabel label("Mean4D/Uint8");
1133   // The current implementation only supports 4-D tensors and simultaneous
1134   // reduction over width and height.
1135   TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4);
1136   TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1137   const RuntimeShape input_shape =
1138       RuntimeShape::ExtendedShape(4, unextended_input_shape);
1139   const RuntimeShape output_shape =
1140       RuntimeShape::ExtendedShape(4, unextended_output_shape);
1141   const int output_height = output_shape.Dims(1);
1142   const int output_width = output_shape.Dims(2);
1143   const int output_depth = output_shape.Dims(3);
1144 
1145   TFLITE_CHECK_EQ(op_params.axis_count, 2);
1146   TFLITE_CHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
1147                (op_params.axis[0] == 2 && op_params.axis[1] == 1));
1148   TFLITE_CHECK_EQ(output_height, 1);
1149   TFLITE_CHECK_EQ(output_width, 1);
1150 
1151   const int input_height = input_shape.Dims(1);
1152   const int input_width = input_shape.Dims(2);
1153   const float num_elements_in_axis = input_width * input_height;
1154 
1155   float temp = input_zero_point * input_scale / output_scale;
1156   temp = temp > 0 ? temp + 0.5f : temp - 0.5f;
1157   int32_t bias = output_zero_point - static_cast<int32_t>(temp);
1158   float real_scale = input_scale / (num_elements_in_axis * output_scale);
1159 
1160   int32 multiplier, shift;
1161   QuantizeMultiplier(real_scale, &multiplier, &shift);
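  // The quantized mean is computed as output_q = bias + sum_q * real_scale,
  // which follows from requantizing
  //   mean_real = input_scale * (sum_q / N - input_zero_point)
  // into the output scale and zero point, with N = num_elements_in_axis.
  // Illustrative example (values are arbitrary): input_scale = 0.5,
  // output_scale = 0.25, input_zero_point = 2, output_zero_point = 3, N = 4
  // gives bias = 3 - round(2 * 0.5 / 0.25) = -1 and real_scale = 0.5.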
1162 
1163   constexpr int kMinDepthPerThread = 8;
1164   int thread_count = output_depth / kMinDepthPerThread;
1165   thread_count = thread_count > 0 ? thread_count : 1;
1166   const int capped_thread_count =
1167       std::min(thread_count, cpu_backend_context->max_num_threads());
1168 
1169   if (capped_thread_count == 1) {
1170     MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias,
1171              output_shape, output_data, 0, output_depth);
1172   } else {
1173     // Instead of parallelizing over the batch, we parallelize over
1174     // output_depth, since the batch size is typically 1.
1175     std::vector<MeanWorkerTask> tasks;
1176     // TODO(b/131746020) don't create new heap allocations every time.
1177     // At least we make it a single heap allocation by using reserve().
1178     tasks.reserve(capped_thread_count);
1179     int depth_start = 0;
1180     for (int i = 0; i < capped_thread_count; ++i) {
1181       // Try to distribute the tasks as evenly as possible.
1182       int depth_end = depth_start +
1183                       (output_depth - depth_start) / (capped_thread_count - i);
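      // Illustrative example: with output_depth = 10 and
      // capped_thread_count = 3, this yields the splits [0, 3), [3, 6) and
      // [6, 10).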
1184       tasks.emplace_back(op_params, input_shape, input_data, multiplier, shift,
1185                          bias, output_shape, output_data, depth_start,
1186                          depth_end);
1187       depth_start = depth_end;
1188     }
1189     cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
1190                                     cpu_backend_context);
1191   }
1192 }
1193 
1194 template <typename T, typename U>
1195 inline bool MeanGeneral(const T* input_data, const int* input_dims,
1196                         const int input_num_dims, T* output_data,
1197                         const int* output_dims, const int output_num_dims,
1198                         const int* axis, const int num_axis_dimensions,
1199                         bool keep_dims, int* temp_index, int* resolved_axis,
1200                         U* temp_sum) {
1201   return reference_ops::Mean(input_data, input_dims, input_num_dims,
1202                              output_data, output_dims, output_num_dims, axis,
1203                              num_axis_dimensions, keep_dims, temp_index,
1204                              resolved_axis, temp_sum);
1205 }
1206 
1207 template <>
1208 inline bool MeanGeneral<float, float>(
1209     const float* input_data, const int* input_dims, const int input_num_dims,
1210     float* output_data, const int* output_dims, const int output_num_dims,
1211     const int* axis, const int num_axis_dimensions, bool keep_dims,
1212     int* temp_index, int* resolved_axis, float* temp_sum) {
1213   // Handle reduce_mean for the last dimensions.
1214   if (num_axis_dimensions == 1 && axis[0] == (input_num_dims - 1)) {
1215     ruy::profiler::ScopeLabel label("MeanLastDim/Float");
1216     int output_size = 1;
1217     for (int i = 0; i < input_num_dims - 1; ++i) {
1218       output_size *= input_dims[i];
1219     }
1220     const int last_input_dim = input_dims[axis[0]];
1221 
1222     // TODO(b/152563685): Consider use eigen to cover more general cases.
1223     const MatrixMap<const float> in_mat(input_data, last_input_dim,
1224                                         output_size);
1225     VectorMap<float> out(output_data, output_size, 1);
1226     out = (in_mat.array().colwise().sum()) / static_cast<float>(last_input_dim);
1227     return true;
1228   }
1229 
1230   return reference_ops::Mean(input_data, input_dims, input_num_dims,
1231                              output_data, output_dims, output_num_dims, axis,
1232                              num_axis_dimensions, keep_dims, temp_index,
1233                              resolved_axis, temp_sum);
1234 }
1235 
1236 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
1237                  const float* input_data, const RuntimeShape& filter_shape,
1238                  const float* filter_data, const RuntimeShape& bias_shape,
1239                  const float* bias_data, const RuntimeShape& output_shape,
1240                  float* output_data, const RuntimeShape& im2col_shape,
1241                  float* im2col_data, CpuBackendContext* cpu_backend_context) {
1242   const int stride_width = params.stride_width;
1243   const int stride_height = params.stride_height;
1244   const int dilation_width_factor = params.dilation_width_factor;
1245   const int dilation_height_factor = params.dilation_height_factor;
1246   const float output_activation_min = params.float_activation_min;
1247   const float output_activation_max = params.float_activation_max;
1248   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1249   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1250   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1251 
1252   ruy::profiler::ScopeLabel label("Conv");
1253 
1254   // NB: the float 0.0f value is represented by all zero bytes.
1255   const uint8 float_zero_byte = 0x00;
1256   const float* gemm_input_data = nullptr;
1257   const RuntimeShape* gemm_input_shape = nullptr;
1258   const int filter_width = filter_shape.Dims(2);
1259   const int filter_height = filter_shape.Dims(1);
1260   const bool need_dilated_im2col =
1261       dilation_width_factor != 1 || dilation_height_factor != 1;
1262   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1263                            filter_width != 1 || filter_height != 1;
1264   if (need_dilated_im2col) {
1265     DilatedIm2col(params, float_zero_byte, input_shape, input_data,
1266                   filter_shape, output_shape, im2col_data);
1267     gemm_input_data = im2col_data;
1268     gemm_input_shape = &im2col_shape;
1269   } else if (need_im2col) {
1270     TFLITE_DCHECK(im2col_data);
1271     Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
1272            input_data, im2col_shape, im2col_data);
1273     gemm_input_data = im2col_data;
1274     gemm_input_shape = &im2col_shape;
1275   } else {
1276     TFLITE_DCHECK(!im2col_data);
1277     gemm_input_data = input_data;
1278     gemm_input_shape = &input_shape;
1279   }
1280 
1281   const int gemm_input_dims = gemm_input_shape->DimensionsCount();
1282   int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
1283   int n = output_shape.Dims(3);
1284   int k = gemm_input_shape->Dims(gemm_input_dims - 1);
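  // Here m is the number of output pixels (batch * output_height *
  // output_width), n is the number of output channels, and k is the length of
  // each im2col patch (filter_height * filter_width * input_depth, or just
  // input_depth for a 1x1 stride-1 convolution). Illustrative example: a 3x3
  // filter over 16 input channels producing a 1x8x8x32 output gives m = 64,
  // n = 32, k = 144.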
1285 
1286 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
1287   // The following code computes the matrix multiplication c = a * transpose(b)
1288   // with CBLAS, where:
1289   // * `a` is a matrix with dimensions (m, k).
1290   // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n).
1291   // * `c` is a matrix with dimensions (m, n).
1292   // The naming of the variables is aligned with the CBLAS specification.
1293   const float* a = gemm_input_data;
1294   const float* b = filter_data;
1295   float* c = output_data;
1296   // The strides of matrices a, b and c, respectively.
1297   int stride_a = k;
1298   int stride_b = k;
1299   int stride_c = n;
1300 
1301   cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
1302               stride_a, b, stride_b, 0.0f, c, stride_c);
1303   optimized_ops::AddBiasAndEvalActivationFunction(
1304       output_activation_min, output_activation_max, bias_shape, bias_data,
1305       output_shape, output_data);
1306 #else
1307   // When an optimized CBLAS implementation is not available, fall back
1308   // to using cpu_backend_gemm.
1309   cpu_backend_gemm::MatrixParams<float> lhs_params;
1310   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1311   lhs_params.rows = n;
1312   lhs_params.cols = k;
1313   cpu_backend_gemm::MatrixParams<float> rhs_params;
1314   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1315   rhs_params.rows = k;
1316   rhs_params.cols = m;
1317   cpu_backend_gemm::MatrixParams<float> dst_params;
1318   dst_params.order = cpu_backend_gemm::Order::kColMajor;
1319   dst_params.rows = n;
1320   dst_params.cols = m;
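  // With this layout the GEMM computes dst(n x m) = lhs(n x k) * rhs(k x m):
  // each column of dst holds the n output channels of one output pixel, which
  // matches the contiguous per-pixel channel layout of the NHWC output buffer.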
1321   cpu_backend_gemm::GemmParams<float, float> gemm_params;
1322   gemm_params.bias = bias_data;
1323   gemm_params.clamp_min = output_activation_min;
1324   gemm_params.clamp_max = output_activation_max;
1325   cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1326                          dst_params, output_data, gemm_params,
1327                          cpu_backend_context);
1328 #endif  //  defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
1329 }
1330 
1331 inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
1332                        const RuntimeShape& input_shape,
1333                        const int8_t* input_data,
1334                        const RuntimeShape& filter_shape,
1335                        const int8_t* filter_data,
1336                        const RuntimeShape& bias_shape, const float* bias_data,
1337                        const RuntimeShape& accum_scratch_shape,
1338                        int32_t* accum_scratch, const RuntimeShape& output_shape,
1339                        float* output_data, const RuntimeShape& im2col_shape,
1340                        int8_t* im2col_data, CpuBackendContext* context) {
1341   const int stride_width = params.stride_width;
1342   const int stride_height = params.stride_height;
1343   const int dilation_width_factor = params.dilation_width_factor;
1344   const int dilation_height_factor = params.dilation_height_factor;
1345   const float output_activation_min = params.float_activation_min;
1346   const float output_activation_max = params.float_activation_max;
1347   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1348   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1349   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1350 
1351   const int batch_size = input_shape.Dims(0);
1352   const int filter_width = filter_shape.Dims(2);
1353   const int filter_height = filter_shape.Dims(1);
1354 
1355   const int input_zero_point = 0;
1356   const int8_t* gemm_input_data = nullptr;
1357   int num_input;
1358   const bool need_dilated_im2col =
1359       dilation_width_factor != 1 || dilation_height_factor != 1;
1360   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1361                            filter_width != 1 || filter_height != 1;
1362 
1363   if (need_dilated_im2col) {
1364     DilatedIm2col(params, input_zero_point, input_shape, input_data,
1365                   filter_shape, output_shape, im2col_data);
1366     gemm_input_data = im2col_data;
1367     num_input = im2col_shape.FlatSize();
1368   } else if (need_im2col) {
1369     TFLITE_DCHECK(im2col_data);
1370     // Symmetric quantization assumes a zero point of 0.
1371 
1372     Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
1373            input_data, im2col_shape, im2col_data);
1374     gemm_input_data = im2col_data;
1375     num_input = im2col_shape.FlatSize();
1376   } else {
1377     TFLITE_DCHECK(!im2col_data);
1378     gemm_input_data = input_data;
1379     num_input = input_shape.FlatSize();
1380   }
1381 
1382   // Flatten 4D matrices into 2D matrices for matrix multiplication.
1383 
1384   // Flatten so that each filter has its own row.
1385   const int filter_rows = filter_shape.Dims(0);
1386   const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1387 
1388   // In MatrixBatchVectorMultiplyAccumulate, each output value is the
1389   // dot product of one row of the first matrix with one row of the second
1390   // matrix. Therefore, the number of cols in each matrix must be equal.
1391   //
1392   // After Im2Col, each input patch becomes a row.
1393   const int gemm_input_cols = filter_cols;
1394   const int gemm_input_rows = num_input / gemm_input_cols;
1395 
1396   const int output_cols = output_shape.Dims(3);
1397   const int output_rows = FlatSizeSkipDim(output_shape, 3);
1398   TFLITE_DCHECK_EQ(output_cols, filter_rows);
1399   TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
1400   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_cols);
1401 
1402   // MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
1403   // input matrix has its own scale factor. This code duplicates the scale
1404   // factors for each row in the same batch.
1405   const int rows_per_batch = gemm_input_rows / batch_size;
1406   for (int i = gemm_input_rows - 1; i >= 0; --i) {
1407     scaling_factors_ptr[i] = scaling_factors_ptr[i / rows_per_batch];
1408   }
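  // Illustrative example: with batch_size = 2, rows_per_batch = 3 and
  // scaling_factors_ptr = {s0, s1, ...}, the loop above expands the buffer in
  // place to {s0, s0, s0, s1, s1, s1}.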
1409 
1410   std::fill_n(output_data, output_rows * output_cols, 0.0f);
1411 
1412   // The scratch buffer must have the same size as the output.
1413   TFLITE_DCHECK_EQ(accum_scratch_shape.FlatSize(), output_shape.FlatSize());
1414   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
1415       filter_data, filter_rows, filter_cols, gemm_input_data,
1416       scaling_factors_ptr, /*n_batch=*/gemm_input_rows, accum_scratch,
1417       output_data, context);
1418   AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
1419                                    bias_shape, bias_data, output_shape,
1420                                    output_data);
1421 }
1422 
1423 inline void HybridConvPerChannel(
1424     const ConvParams& params, float* scaling_factors_ptr,
1425     const RuntimeShape& input_shape, const int8_t* input_data,
1426     const RuntimeShape& filter_shape, const int8_t* filter_data,
1427     const RuntimeShape& bias_shape, const float* bias_data,
1428     const RuntimeShape& output_shape, float* output_data,
1429     const RuntimeShape& im2col_shape, int8_t* im2col_data,
1430     const float* per_channel_scale, int32_t* input_offset,
1431     const RuntimeShape& scratch_shape, int32_t* scratch, int32_t* row_sums,
1432     bool* compute_row_sums, CpuBackendContext* cpu_backend_context) {
1433   ruy::profiler::ScopeLabel label("ConvHybridPerChannel");
1434   const int stride_width = params.stride_width;
1435   const int stride_height = params.stride_height;
1436   const int dilation_width_factor = params.dilation_width_factor;
1437   const int dilation_height_factor = params.dilation_height_factor;
1438   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1439   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1440   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1441 
1442   const int8* gemm_input_data = nullptr;
1443   const RuntimeShape* gemm_input_shape = nullptr;
1444   const int filter_width = filter_shape.Dims(2);
1445   const int filter_height = filter_shape.Dims(1);
1446   const bool need_dilated_im2col =
1447       dilation_width_factor != 1 || dilation_height_factor != 1;
1448   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1449                            filter_width != 1 || filter_height != 1;
1450 
1451   const int batch_size = input_shape.Dims(0);
1452 
1453   if (need_dilated_im2col) {
1454     TFLITE_DCHECK(im2col_data);
1455     optimized_ops::DilatedIm2col(params, input_shape, input_data, filter_shape,
1456                                  output_shape, im2col_data, input_offset,
1457                                  batch_size);
1458     gemm_input_data = im2col_data;
1459     gemm_input_shape = &im2col_shape;
1460   } else if (need_im2col) {
1461     Im2col(params, filter_height, filter_width, input_offset, batch_size,
1462            input_shape, input_data, im2col_shape, im2col_data);
1463     gemm_input_data = im2col_data;
1464     gemm_input_shape = &im2col_shape;
1465   } else {
1466     TFLITE_DCHECK(!im2col_data);
1467     gemm_input_data = input_data;
1468     gemm_input_shape = &input_shape;
1469   }
1470 
1471   const int filter_rows = filter_shape.Dims(0);
1472   const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1473 
1474   const int gemm_input_rows = gemm_input_shape->Dims(3);
1475   const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
1476   const int output_rows = output_shape.Dims(3);
1477   const int output_cols =
1478       output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
1479 
1480   TFLITE_DCHECK_EQ(output_rows, filter_rows);
1481   TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
1482   TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
1483   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1484   TFLITE_DCHECK_EQ(scratch_shape.FlatSize(), output_shape.FlatSize());
1485   if (!compute_row_sums || *compute_row_sums) {
1486     tensor_utils::ReductionSumVector(filter_data, row_sums, filter_rows,
1487                                      filter_cols);
1488     if (compute_row_sums) {
1489       *compute_row_sums = false;
1490     }
1491   }
1492 
1493   cpu_backend_gemm::MatrixParams<int8> lhs_params;
1494   lhs_params.rows = filter_rows;
1495   lhs_params.cols = filter_cols;
1496   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1497 
1498   cpu_backend_gemm::MatrixParams<int8> rhs_params;
1499   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1500   rhs_params.rows = gemm_input_rows;
1501   rhs_params.cols = gemm_input_cols;
1502 
1503   cpu_backend_gemm::MatrixParams<int32> dst_params;
1504   dst_params.order = cpu_backend_gemm::Order::kColMajor;
1505   dst_params.rows = output_rows;
1506   dst_params.cols = output_cols;
1507 
1508   // TODO(b/149003801): Use hybrid gemm once supported in Ruy.
1509   cpu_backend_gemm::GemmParams<int32_t, int32_t> gemm_params;
1510   cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1511                          dst_params, scratch, gemm_params, cpu_backend_context);
1512 
1513   MatrixMap<float> out_mat(output_data, filter_rows, output_cols);
1514   MatrixMap<int32_t> in_mat(scratch, filter_rows, output_cols);
1515   VectorMap<const float> bias_data_vec(bias_data, filter_rows, 1);
1516   VectorMap<int32_t> row_sums_vec(row_sums, filter_rows, 1);
1517   VectorMap<const float> per_channel_scale_vec(per_channel_scale, filter_rows,
1518                                                1);
1519   const int cols_per_batch = output_cols / batch_size;
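  // Dequantize and bias each accumulator column. Roughly, for column c in
  // batch b:
  //   out[:, c] = clamp((acc[:, c] - row_sums * input_offset[b])
  //                         * per_channel_scale * input_scale[b]
  //                     + bias_data)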
1520   for (int c = 0; c < output_cols; c++) {
1521     const int b = c / cols_per_batch;
1522     const float input_scale = scaling_factors_ptr[b];
1523     const int32_t zero_point = input_offset[b];
1524     out_mat.col(c) =
1525         (((in_mat.col(c) - (row_sums_vec * zero_point))
1526               .cast<float>()
1527               .cwiseProduct((per_channel_scale_vec * input_scale))) +
1528          bias_data_vec)
1529             .cwiseMin(params.float_activation_max)
1530             .cwiseMax(params.float_activation_min);
1531   }
1532 }
1533 
1534 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
1535                  const uint8* input_data, const RuntimeShape& filter_shape,
1536                  const uint8* filter_data, const RuntimeShape& bias_shape,
1537                  const int32* bias_data, const RuntimeShape& output_shape,
1538                  uint8* output_data, const RuntimeShape& im2col_shape,
1539                  uint8* im2col_data, CpuBackendContext* cpu_backend_context) {
1540   ruy::profiler::ScopeLabel label("Conv/8bit");
1541 
1542   const int stride_width = params.stride_width;
1543   const int stride_height = params.stride_height;
1544   const int dilation_width_factor = params.dilation_width_factor;
1545   const int dilation_height_factor = params.dilation_height_factor;
1546   const int32 input_offset = params.input_offset;
1547   const int32 filter_offset = params.weights_offset;
1548   const int32 output_offset = params.output_offset;
1549   const int32 output_multiplier = params.output_multiplier;
1550   const int output_shift = params.output_shift;
1551   const int32 output_activation_min = params.quantized_activation_min;
1552   const int32 output_activation_max = params.quantized_activation_max;
1553   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1554   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1555   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1556 
1557   const uint8* gemm_input_data = nullptr;
1558   const RuntimeShape* gemm_input_shape = nullptr;
1559   const int filter_width = filter_shape.Dims(2);
1560   const int filter_height = filter_shape.Dims(1);
1561   const bool need_dilated_im2col =
1562       dilation_width_factor != 1 || dilation_height_factor != 1;
1563   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1564                            filter_width != 1 || filter_height != 1;
1565   if (need_dilated_im2col) {
1566     TFLITE_DCHECK(im2col_data);
1567     const int input_zero_point = -input_offset;
1568     TFLITE_DCHECK_GE(input_zero_point, 0);
1569     TFLITE_DCHECK_LE(input_zero_point, 255);
1570     DilatedIm2col(params, input_zero_point, input_shape, input_data,
1571                   filter_shape, output_shape, im2col_data);
1572     gemm_input_data = im2col_data;
1573     gemm_input_shape = &im2col_shape;
1574   } else if (need_im2col) {
1575     TFLITE_DCHECK(im2col_data);
1576     const int input_zero_point = -input_offset;
1577     TFLITE_DCHECK_GE(input_zero_point, 0);
1578     TFLITE_DCHECK_LE(input_zero_point, 255);
1579     Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
1580            input_data, im2col_shape, im2col_data);
1581     gemm_input_data = im2col_data;
1582     gemm_input_shape = &im2col_shape;
1583   } else {
1584     TFLITE_DCHECK(!im2col_data);
1585     gemm_input_data = input_data;
1586     gemm_input_shape = &input_shape;
1587   }
1588 
1589   const int gemm_input_rows = gemm_input_shape->Dims(3);
1590   // Using FlatSizeSkipDim causes a segfault in some contexts (see b/79927784).
1591   // The root cause has not yet been identified though. Same applies below for
1592   // the other calls commented out. This is a partial rollback of cl/196819423.
1593   // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
1594   const int gemm_input_cols = gemm_input_shape->Dims(0) *
1595                               gemm_input_shape->Dims(1) *
1596                               gemm_input_shape->Dims(2);
1597   const int filter_rows = filter_shape.Dims(0);
1598   // See b/79927784.
1599   // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1600   const int filter_cols =
1601       filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
1602   const int output_rows = output_shape.Dims(3);
1603   // See b/79927784.
1604   // const int output_cols = FlatSizeSkipDim(output_shape, 3);
1605   const int output_cols =
1606       output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
1607   TFLITE_DCHECK_EQ(output_rows, filter_rows);
1608   TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
1609   TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
1610   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1611 
1612   cpu_backend_gemm::MatrixParams<uint8> lhs_params;
1613   lhs_params.rows = filter_rows;
1614   lhs_params.cols = filter_cols;
1615   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1616   lhs_params.zero_point = -filter_offset;
1617   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
1618   rhs_params.rows = gemm_input_rows;
1619   rhs_params.cols = gemm_input_cols;
1620   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1621   rhs_params.zero_point = -input_offset;
1622   cpu_backend_gemm::MatrixParams<uint8> dst_params;
1623   dst_params.rows = output_rows;
1624   dst_params.cols = output_cols;
1625   dst_params.order = cpu_backend_gemm::Order::kColMajor;
1626   dst_params.zero_point = output_offset;
1627   cpu_backend_gemm::GemmParams<int32, uint8> gemm_params;
1628   gemm_params.bias = bias_data;
1629   gemm_params.clamp_min = output_activation_min;
1630   gemm_params.clamp_max = output_activation_max;
1631   gemm_params.multiplier_fixedpoint = output_multiplier;
1632   gemm_params.multiplier_exponent = output_shift;
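  // The GEMM below adds bias_data to each int32 accumulator, rescales it with
  // the fixed-point multiplier (roughly acc * output_multiplier / 2^31, then
  // shifted by output_shift), adds the destination zero point (output_offset),
  // and clamps to [output_activation_min, output_activation_max] before
  // narrowing to uint8.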
1633   cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1634                          dst_params, output_data, gemm_params,
1635                          cpu_backend_context);
1636 }
1637 
1638 template <typename T>
1639 inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
1640                          const RuntimeShape& unextended_input_shape,
1641                          const T* input_data,
1642                          const RuntimeShape& unextended_output_shape,
1643                          T* output_data) {
1644   ruy::profiler::ScopeLabel label("DepthToSpace");
1645 
1646   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1647   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1648   const RuntimeShape input_shape =
1649       RuntimeShape::ExtendedShape(4, unextended_input_shape);
1650   const RuntimeShape output_shape =
1651       RuntimeShape::ExtendedShape(4, unextended_output_shape);
1652 
1653   const int input_depth = input_shape.Dims(3);
1654   const int input_width = input_shape.Dims(2);
1655   const int input_height = input_shape.Dims(1);
1656 
1657   const int output_depth = output_shape.Dims(3);
1658   const int batch_size = output_shape.Dims(0);
1659 
1660   // Number of contiguous values that we can copy in one iteration.
1661   const int stride = op_params.block_size * output_depth;
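  // Illustrative example: with block_size = 2 and output_depth = 4, each
  // memcpy below copies stride = 8 values, i.e. the output_depth channels of
  // block_size adjacent output pixels.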
1662 
1663   for (int batch = 0; batch < batch_size; ++batch) {
1664     for (int in_h = 0; in_h < input_height; ++in_h) {
1665       const T* input_ptr = input_data + Offset(input_shape, batch, in_h, 0, 0);
1666       for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
1667         const T* src = input_ptr;
1668         for (int in_w = 0; in_w < input_width; ++in_w) {
1669           memcpy(output_data, src, stride * sizeof(T));
1670           output_data += stride;
1671           src += input_depth;
1672         }
1673         input_ptr += stride;
1674       }
1675     }
1676   }
1677 }
1678 
1679 template <typename T>
1680 inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
1681                          const RuntimeShape& unextended_input_shape,
1682                          const T* input_data,
1683                          const RuntimeShape& unextended_output_shape,
1684                          T* output_data) {
1685   ruy::profiler::ScopeLabel label("SpaceToDepth");
1686 
1687   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1688   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1689   const RuntimeShape input_shape =
1690       RuntimeShape::ExtendedShape(4, unextended_input_shape);
1691   const RuntimeShape output_shape =
1692       RuntimeShape::ExtendedShape(4, unextended_output_shape);
1693 
1694   const int output_depth = output_shape.Dims(3);
1695   const int output_width = output_shape.Dims(2);
1696   const int output_height = output_shape.Dims(1);
1697 
1698   const int input_depth = input_shape.Dims(3);
1699   const int batch_size = input_shape.Dims(0);
1700 
1701   // Number of contiguous values that we can copy in one iteration.
1702   const int stride = op_params.block_size * input_depth;
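  // Illustrative example: with block_size = 2 and input_depth = 4, each
  // memcpy below copies stride = 8 values, i.e. block_size adjacent input
  // pixels into one slice of an output pixel's depth.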
1703 
1704   for (int batch = 0; batch < batch_size; ++batch) {
1705     for (int out_h = 0; out_h < output_height; ++out_h) {
1706       T* output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0);
1707       for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
1708         T* dst = output_ptr;
1709         for (int out_w = 0; out_w < output_width; ++out_w) {
1710           memcpy(dst, input_data, stride * sizeof(T));
1711           input_data += stride;
1712           dst += output_depth;
1713         }
1714         output_ptr += stride;
1715       }
1716     }
1717   }
1718 }
1719 
1720 inline void Relu(const RuntimeShape& input_shape, const float* input_data,
1721                  const RuntimeShape& output_shape, float* output_data) {
1722   ruy::profiler::ScopeLabel label("Relu (not fused)");
1723 
1724   const auto input = MapAsVector(input_data, input_shape);
1725   auto output = MapAsVector(output_data, output_shape);
1726   output = input.cwiseMax(0.0f);
1727 }
1728 
1729 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
1730                             const RuntimeShape& input_shape,
1731                             const float* input_data,
1732                             const RuntimeShape& output_shape,
1733                             float* output_data, float epsilon = 1e-6) {
1734   ruy::profiler::ScopeLabel label("L2Normalization");
1735   const int trailing_dim = input_shape.DimensionsCount() - 1;
1736   const int outer_size =
1737       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
1738   const int depth =
1739       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
1740   for (int i = 0; i < outer_size; ++i) {
1741     float squared_l2_norm = 0;
1742     for (int c = 0; c < depth; ++c) {
1743       const float val = input_data[c];
1744       squared_l2_norm += val * val;
1745     }
1746     float l2_norm = std::sqrt(squared_l2_norm);
1747     l2_norm = std::max(l2_norm, epsilon);
1748     for (int c = 0; c < depth; ++c) {
1749       *output_data = *input_data / l2_norm;
1750       ++output_data;
1751       ++input_data;
1752     }
1753   }
1754 }
1755 
1756 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
1757                             const RuntimeShape& input_shape,
1758                             const uint8* input_data,
1759                             const RuntimeShape& output_shape,
1760                             uint8* output_data) {
1761   ruy::profiler::ScopeLabel label("L2Normalization/8bit");
1762   const int trailing_dim = input_shape.DimensionsCount() - 1;
1763   const int depth =
1764       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
1765   const int outer_size =
1766       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
1767   const int32 input_zero_point = op_params.input_zero_point;
1768   for (int i = 0; i < outer_size; ++i) {
1769     int32 square_l2_norm = 0;
1770     for (int c = 0; c < depth; c++) {
1771       // Note that input_data advances by depth in the second pass below.
1772       int32 diff = input_data[c] - input_zero_point;
1773       square_l2_norm += diff * diff;
1774     }
1775     // TODO(b/29395854): add clamping to TOCO and TF Lite kernel
1776     // for all zero tensors in the input_data
1777     int32 inv_l2norm_multiplier;
1778     int inv_l2norm_shift;
1779     GetInvSqrtQuantizedMultiplierExp(square_l2_norm, kReverseShift,
1780                                      &inv_l2norm_multiplier, &inv_l2norm_shift);
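    // The uint8 output of L2 normalization is expected to use zero point 128
    // and scale 1/128, so the normalized value (x - zero_point) * inv_l2norm,
    // which lies in [-1, 1], is mapped to 128 + 128 * (x - zero_point) *
    // inv_l2norm in the loop below.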
1781 
1782     for (int c = 0; c < depth; c++) {
1783       int32 diff = *input_data - input_zero_point;
1784       int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
1785           128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
1786       int32 unclamped_output_val = 128 + rescaled_diff;
1787       int32 output_val = std::min(255, std::max(0, unclamped_output_val));
1788       *output_data = static_cast<uint8>(output_val);
1789       ++input_data;
1790       ++output_data;
1791     }
1792   }
1793 }
1794 
1795 inline void AddElementwise(int size, const ArithmeticParams& params,
1796                            const float* input1_data, const float* input2_data,
1797                            float* output_data) {
1798   int i = 0;
1799 
1800 #ifdef USE_NEON
1801   const auto activation_min = vdupq_n_f32(params.float_activation_min);
1802   const auto activation_max = vdupq_n_f32(params.float_activation_max);
1803   for (; i <= size - 16; i += 16) {
1804     auto a10 = vld1q_f32(input1_data + i);
1805     auto a11 = vld1q_f32(input1_data + i + 4);
1806     auto a12 = vld1q_f32(input1_data + i + 8);
1807     auto a13 = vld1q_f32(input1_data + i + 12);
1808     auto a20 = vld1q_f32(input2_data + i);
1809     auto a21 = vld1q_f32(input2_data + i + 4);
1810     auto a22 = vld1q_f32(input2_data + i + 8);
1811     auto a23 = vld1q_f32(input2_data + i + 12);
1812     auto x0 = vaddq_f32(a10, a20);
1813     auto x1 = vaddq_f32(a11, a21);
1814     auto x2 = vaddq_f32(a12, a22);
1815     auto x3 = vaddq_f32(a13, a23);
1816     x0 = vmaxq_f32(activation_min, x0);
1817     x1 = vmaxq_f32(activation_min, x1);
1818     x2 = vmaxq_f32(activation_min, x2);
1819     x3 = vmaxq_f32(activation_min, x3);
1820     x0 = vminq_f32(activation_max, x0);
1821     x1 = vminq_f32(activation_max, x1);
1822     x2 = vminq_f32(activation_max, x2);
1823     x3 = vminq_f32(activation_max, x3);
1824     vst1q_f32(output_data + i, x0);
1825     vst1q_f32(output_data + i + 4, x1);
1826     vst1q_f32(output_data + i + 8, x2);
1827     vst1q_f32(output_data + i + 12, x3);
1828   }
1829   for (; i <= size - 4; i += 4) {
1830     auto a1 = vld1q_f32(input1_data + i);
1831     auto a2 = vld1q_f32(input2_data + i);
1832     auto x = vaddq_f32(a1, a2);
1833     x = vmaxq_f32(activation_min, x);
1834     x = vminq_f32(activation_max, x);
1835     vst1q_f32(output_data + i, x);
1836   }
1837 #endif  // NEON
1838 
1839   for (; i < size; i++) {
1840     auto x = input1_data[i] + input2_data[i];
1841     output_data[i] = ActivationFunctionWithMinMax(
1842         x, params.float_activation_min, params.float_activation_max);
1843   }
1844 }
1845 
1846 inline void Add(const ArithmeticParams& params,
1847                 const RuntimeShape& input1_shape, const float* input1_data,
1848                 const RuntimeShape& input2_shape, const float* input2_data,
1849                 const RuntimeShape& output_shape, float* output_data) {
1850   ruy::profiler::ScopeLabel label("Add");
1851   const int flat_size =
1852       MatchingElementsSize(input1_shape, input2_shape, output_shape);
1853   AddElementwise(flat_size, params, input1_data, input2_data, output_data);
1854 }
1855 
1856 // Element-wise add that can often be used for the inner loop of a broadcast
1857 // add as well as for the non-broadcast add.
1858 inline void AddElementwise(int size, const ArithmeticParams& params,
1859                            const uint8* input1_data, const uint8* input2_data,
1860                            uint8* output_data) {
1861   ruy::profiler::ScopeLabel label("AddElementwise/8bit");
1862   int i = 0;
1863   TFLITE_DCHECK_GT(params.input1_offset, -256);
1864   TFLITE_DCHECK_GT(params.input2_offset, -256);
1865   TFLITE_DCHECK_LT(params.input1_offset, 256);
1866   TFLITE_DCHECK_LT(params.input2_offset, 256);
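  // Both the NEON and the scalar paths implement the same fixed-point recipe:
  // shift each (input + input_offset) left by params.left_shift, rescale by
  // the per-input multiplier/shift, add the two terms, rescale the sum by the
  // output multiplier/shift, add output_offset, and clamp to the quantized
  // activation range. The input offsets are the negated input zero points.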
1867 #ifdef USE_NEON
1868   const uint8x8_t output_activation_min_vector =
1869       vdup_n_u8(params.quantized_activation_min);
1870   const uint8x8_t output_activation_max_vector =
1871       vdup_n_u8(params.quantized_activation_max);
1872   for (; i <= size - 8; i += 8) {
1873     const uint8x8_t input1_val_original = vld1_u8(input1_data + i);
1874     const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
1875     const int16x8_t input1_val_s16 =
1876         vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
1877     const int16x8_t input2_val_s16 =
1878         vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
1879     const int16x8_t input1_val =
1880         vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
1881     const int16x8_t input2_val =
1882         vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
1883     const int16x4_t input1_val_high = vget_high_s16(input1_val);
1884     const int16x4_t input1_val_low = vget_low_s16(input1_val);
1885     const int16x4_t input2_val_high = vget_high_s16(input2_val);
1886     const int16x4_t input2_val_low = vget_low_s16(input2_val);
1887     int32x4_t x11 = vmovl_s16(input1_val_low);
1888     int32x4_t x12 = vmovl_s16(input1_val_high);
1889     int32x4_t x21 = vmovl_s16(input2_val_low);
1890     int32x4_t x22 = vmovl_s16(input2_val_high);
1891     const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
1892     x11 = vshlq_s32(x11, left_shift_dup);
1893     x12 = vshlq_s32(x12, left_shift_dup);
1894     x21 = vshlq_s32(x21, left_shift_dup);
1895     x22 = vshlq_s32(x22, left_shift_dup);
1896     x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
1897     x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
1898     x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
1899     x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
1900     const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
1901     const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
1902     x11 = vshlq_s32(x11, input1_shift_dup);
1903     x12 = vshlq_s32(x12, input1_shift_dup);
1904     x21 = vshlq_s32(x21, input2_shift_dup);
1905     x22 = vshlq_s32(x22, input2_shift_dup);
1906     int32x4_t s1 = vaddq_s32(x11, x21);
1907     int32x4_t s2 = vaddq_s32(x12, x22);
1908     s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
1909     s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
1910     using gemmlowp::RoundingDivideByPOT;
1911     s1 = RoundingDivideByPOT(s1, -params.output_shift);
1912     s2 = RoundingDivideByPOT(s2, -params.output_shift);
1913     const int16x4_t s1_narrowed = vmovn_s32(s1);
1914     const int16x4_t s2_narrowed = vmovn_s32(s2);
1915     const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
1916                                   vdupq_n_s16(params.output_offset));
1917     const uint8x8_t clamped =
1918         vmax_u8(output_activation_min_vector,
1919                 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
1920     vst1_u8(output_data + i, clamped);
1921   }
1922 #endif  // NEON
1923 
1924   for (; i < size; ++i) {
1925     const int32 input1_val = params.input1_offset + input1_data[i];
1926     const int32 input2_val = params.input2_offset + input2_data[i];
1927     const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
1928     const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
1929     const int32 scaled_input1_val =
1930         MultiplyByQuantizedMultiplierSmallerThanOneExp(
1931             shifted_input1_val, params.input1_multiplier, params.input1_shift);
1932     const int32 scaled_input2_val =
1933         MultiplyByQuantizedMultiplierSmallerThanOneExp(
1934             shifted_input2_val, params.input2_multiplier, params.input2_shift);
1935     const int32 raw_sum = scaled_input1_val + scaled_input2_val;
1936     const int32 raw_output =
1937         MultiplyByQuantizedMultiplierSmallerThanOneExp(
1938             raw_sum, params.output_multiplier, params.output_shift) +
1939         params.output_offset;
1940     const int32 clamped_output =
1941         std::min(params.quantized_activation_max,
1942                  std::max(params.quantized_activation_min, raw_output));
1943     output_data[i] = static_cast<uint8>(clamped_output);
1944   }
1945 }
1946 
1947 // Scalar-broadcast add that can be used for the inner loop of a more general
1948 // broadcast add, so that, for example, scalar-broadcast with batch will still
1949 // be fast.
1950 inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
1951                                uint8 input1_data, const uint8* input2_data,
1952                                uint8* output_data) {
1953   using gemmlowp::RoundingDivideByPOT;
1954 
1955   ruy::profiler::ScopeLabel label("AddScalarBroadcast/8bit");
1956   TFLITE_DCHECK_GT(params.input1_offset, -256);
1957   TFLITE_DCHECK_GT(params.input2_offset, -256);
1958   TFLITE_DCHECK_LT(params.input1_offset, 256);
1959   TFLITE_DCHECK_LT(params.input2_offset, 256);
1960 
1961   int i = 0;
1962 
1963 #ifdef USE_NEON
1964   const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
1965   const uint8x8_t output_activation_min_vector =
1966       vdup_n_u8(params.quantized_activation_min);
1967   const uint8x8_t output_activation_max_vector =
1968       vdup_n_u8(params.quantized_activation_max);
1969 
1970   // Process broadcast scalar.
1971   const uint8x8_t input1_val_original = vdup_n_u8(input1_data);
1972   const int16x8_t input1_val_s16 =
1973       vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
1974   const int16x8_t input1_val =
1975       vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
1976   const int16x4_t input1_val_high = vget_high_s16(input1_val);
1977   const int16x4_t input1_val_low = vget_low_s16(input1_val);
1978   int32x4_t x11 = vmovl_s16(input1_val_low);
1979   int32x4_t x12 = vmovl_s16(input1_val_high);
1980   x11 = vshlq_s32(x11, left_shift_dup);
1981   x12 = vshlq_s32(x12, left_shift_dup);
1982   x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
1983   x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
1984   const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
1985   x11 = vshlq_s32(x11, input1_shift_dup);
1986   x12 = vshlq_s32(x12, input1_shift_dup);
1987 
1988   for (; i <= size - 8; i += 8) {
1989     const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
1990     const int16x8_t input2_val_s16 =
1991         vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
1992     const int16x8_t input2_val =
1993         vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
1994     const int16x4_t input2_val_high = vget_high_s16(input2_val);
1995     const int16x4_t input2_val_low = vget_low_s16(input2_val);
1996     int32x4_t x21 = vmovl_s16(input2_val_low);
1997     int32x4_t x22 = vmovl_s16(input2_val_high);
1998     x21 = vshlq_s32(x21, left_shift_dup);
1999     x22 = vshlq_s32(x22, left_shift_dup);
2000     x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
2001     x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
2002     const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
2003     x21 = vshlq_s32(x21, input2_shift_dup);
2004     x22 = vshlq_s32(x22, input2_shift_dup);
2005     int32x4_t s1 = vaddq_s32(x11, x21);
2006     int32x4_t s2 = vaddq_s32(x12, x22);
2007     s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
2008     s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
2009     s1 = RoundingDivideByPOT(s1, -params.output_shift);
2010     s2 = RoundingDivideByPOT(s2, -params.output_shift);
2011     const int16x4_t s1_narrowed = vmovn_s32(s1);
2012     const int16x4_t s2_narrowed = vmovn_s32(s2);
2013     const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
2014                                   vdupq_n_s16(params.output_offset));
2015     const uint8x8_t clamped =
2016         vmax_u8(output_activation_min_vector,
2017                 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
2018     vst1_u8(output_data + i, clamped);
2019   }
2020 #endif  // NEON
2021 
2022   if (i < size) {
2023     // Process broadcast scalar.
2024     const int32 input1_val = params.input1_offset + input1_data;
2025     const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
2026     const int32 scaled_input1_val =
2027         MultiplyByQuantizedMultiplierSmallerThanOneExp(
2028             shifted_input1_val, params.input1_multiplier, params.input1_shift);
2029 
2030     for (; i < size; ++i) {
2031       const int32 input2_val = params.input2_offset + input2_data[i];
2032       const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
2033       const int32 scaled_input2_val =
2034           MultiplyByQuantizedMultiplierSmallerThanOneExp(
2035               shifted_input2_val, params.input2_multiplier,
2036               params.input2_shift);
2037       const int32 raw_sum = scaled_input1_val + scaled_input2_val;
2038       const int32 raw_output =
2039           MultiplyByQuantizedMultiplierSmallerThanOneExp(
2040               raw_sum, params.output_multiplier, params.output_shift) +
2041           params.output_offset;
2042       const int32 clamped_output =
2043           std::min(params.quantized_activation_max,
2044                    std::max(params.quantized_activation_min, raw_output));
2045       output_data[i] = static_cast<uint8>(clamped_output);
2046     }
2047   }
2048 }
2049 
2050 // Scalar-broadcast add that can be used for the inner loop of a more general
2051 // broadcast add, so that, for example, scalar-broadcast with batch will still
2052 // be fast.
2053 inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
2054                                float broadcast_value, const float* input2_data,
2055                                float* output_data) {
2056   int i = 0;
2057 #ifdef USE_NEON
2058   const float32x4_t output_activation_min_vector =
2059       vdupq_n_f32(params.float_activation_min);
2060   const float32x4_t output_activation_max_vector =
2061       vdupq_n_f32(params.float_activation_max);
2062   const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
2063   for (; i <= size - 4; i += 4) {
2064     const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
2065 
2066     const float32x4_t output =
2067         vaddq_f32(input2_val_original, broadcast_value_dup);
2068 
2069     const float32x4_t clamped =
2070         vmaxq_f32(output_activation_min_vector,
2071                   vminq_f32(output_activation_max_vector, output));
2072     vst1q_f32(output_data + i, clamped);
2073   }
2074 #endif  // NEON
2075 
2076   for (; i < size; ++i) {
2077     auto x = broadcast_value + input2_data[i];
2078     output_data[i] = ActivationFunctionWithMinMax(
2079         x, params.float_activation_min, params.float_activation_max);
2080   }
2081 }
2082 
2083 inline void Add(const ArithmeticParams& params,
2084                 const RuntimeShape& input1_shape, const uint8* input1_data,
2085                 const RuntimeShape& input2_shape, const uint8* input2_data,
2086                 const RuntimeShape& output_shape, uint8* output_data) {
2087   TFLITE_DCHECK_LE(params.quantized_activation_min,
2088                    params.quantized_activation_max);
2089   ruy::profiler::ScopeLabel label("Add/8bit");
2090   const int flat_size =
2091       MatchingElementsSize(input1_shape, input2_shape, output_shape);
2092 
2093   TFLITE_DCHECK_GT(params.input1_offset, -256);
2094   TFLITE_DCHECK_GT(params.input2_offset, -256);
2095   TFLITE_DCHECK_LT(params.input1_offset, 256);
2096   TFLITE_DCHECK_LT(params.input2_offset, 256);
2097   AddElementwise(flat_size, params, input1_data, input2_data, output_data);
2098 }
2099 
2100 inline void Add(const ArithmeticParams& params,
2101                 const RuntimeShape& input1_shape, const int16* input1_data,
2102                 const RuntimeShape& input2_shape, const int16* input2_data,
2103                 const RuntimeShape& output_shape, int16* output_data) {
2104   ruy::profiler::ScopeLabel label("Add/Int16");
2105   TFLITE_DCHECK_LE(params.quantized_activation_min,
2106                    params.quantized_activation_max);
2107 
2108   const int input1_shift = params.input1_shift;
2109   const int flat_size =
2110       MatchingElementsSize(input1_shape, input2_shape, output_shape);
2111   const int16 output_activation_min = params.quantized_activation_min;
2112   const int16 output_activation_max = params.quantized_activation_max;
2113 
2114   TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
2115   TFLITE_DCHECK_LE(input1_shift, 0);
2116   TFLITE_DCHECK_LE(params.input2_shift, 0);
2117   const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
2118   const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
2119   const int input_right_shift =
2120       input1_shift == 0 ? -params.input2_shift : -input1_shift;
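  // One of the two inputs is already in the output's Q0.15 scale (its shift
  // is 0); the other is aligned to it by a rounding arithmetic right shift of
  // input_right_shift bits before the saturating add below.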
2121 
2122   for (int i = 0; i < flat_size; i++) {
2123     // F0 uses 0 integer bits, range [-1, 1].
2124     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2125 
2126     F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
2127     F0 scaled_input = F0::FromRaw(
2128         gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
2129     F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
2130     const int16 raw_output = result.raw();
2131     const int16 clamped_output = std::min(
2132         output_activation_max, std::max(output_activation_min, raw_output));
2133     output_data[i] = clamped_output;
2134   }
2135 }
2136 
2137 template <typename T>
2138 inline typename std::enable_if<is_int32_or_int64<T>::value, void>::type Add(
2139     const ArithmeticParams& params, const RuntimeShape& input1_shape,
2140     const T* input1_data, const RuntimeShape& input2_shape,
2141     const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
2142   ruy::profiler::ScopeLabel label("Add/int32or64");
2143 
2144   T activation_min, activation_max;
2145   GetActivationParams(params, &activation_min, &activation_max);
2146 
2147   auto input1_map = MapAsVector(input1_data, input1_shape);
2148   auto input2_map = MapAsVector(input2_data, input2_shape);
2149   auto output_map = MapAsVector(output_data, output_shape);
2150   if (input1_shape == input2_shape) {
2151     output_map.array() = (input1_map.array() + input2_map.array())
2152                              .cwiseMax(activation_min)
2153                              .cwiseMin(activation_max);
2154   } else if (input2_shape.FlatSize() == 1) {
2155     auto scalar = input2_data[0];
2156     output_map.array() = (input1_map.array() + scalar)
2157                              .cwiseMax(activation_min)
2158                              .cwiseMin(activation_max);
2159   } else if (input1_shape.FlatSize() == 1) {
2160     auto scalar = input1_data[0];
2161     output_map.array() = (scalar + input2_map.array())
2162                              .cwiseMax(activation_min)
2163                              .cwiseMin(activation_max);
2164   } else {
2165     reference_ops::BroadcastAdd4DSlow<T>(params, input1_shape, input1_data,
2166                                          input2_shape, input2_data,
2167                                          output_shape, output_data);
2168   }
2169 }
2170 
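// Note on the integer Add above: there are three Eigen fast paths (identical
// shapes, scalar on the right, scalar on the left) and only genuinely
// mismatched shapes fall back to the slow 4-D broadcast. A shape such as
// {1, 1, 1, 1} has FlatSize() == 1, so it takes a scalar path even though it
// is formally four-dimensional.
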
template <typename T>
inline void BroadcastAddDispatch(
    const ArithmeticParams& params, const RuntimeShape& input1_shape,
    const T* input1_data, const RuntimeShape& input2_shape,
    const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
  if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
    return BroadcastAdd4DSlow(params, input1_shape, input1_data, input2_shape,
                              input2_data, output_shape, output_data);
  }

  BinaryBroadcastFiveFold(
      params, input1_shape, input1_data, input2_shape, input2_data,
      output_shape, output_data,
      static_cast<void (*)(int, const ArithmeticParams&, const T*, const T*,
                           T*)>(AddElementwise),
      static_cast<void (*)(int, const ArithmeticParams&, T, const T*, T*)>(
          AddScalarBroadcast));
}

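// Note on BroadcastAddDispatch above: when the broadcast pattern has been
// classified as "fivefold" (alternating broadcast / non-broadcast runs along
// the flattened shapes), BinaryBroadcastFiveFold walks those runs and calls
// AddElementwise on contiguous spans and AddScalarBroadcast when one side
// degenerates to a single value, keeping the inner loops vectorizable.
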
inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
                                 const RuntimeShape& unswitched_input1_shape,
                                 const uint8* unswitched_input1_data,
                                 const RuntimeShape& unswitched_input2_shape,
                                 const uint8* unswitched_input2_data,
                                 const RuntimeShape& output_shape,
                                 uint8* output_data) {
  BroadcastAddDispatch(unswitched_params, unswitched_input1_shape,
                       unswitched_input1_data, unswitched_input2_shape,
                       unswitched_input2_data, output_shape, output_data);
}

inline void BroadcastAddFivefold(const ArithmeticParams& params,
                                 const RuntimeShape& unswitched_input1_shape,
                                 const float* unswitched_input1_data,
                                 const RuntimeShape& unswitched_input2_shape,
                                 const float* unswitched_input2_data,
                                 const RuntimeShape& output_shape,
                                 float* output_data) {
  BroadcastAddDispatch(params, unswitched_input1_shape, unswitched_input1_data,
                       unswitched_input2_shape, unswitched_input2_data,
                       output_shape, output_data);
}

inline void MulElementwise(int size, const ArithmeticParams& params,
                           const float* input1_data, const float* input2_data,
                           float* output_data) {
  const float output_activation_min = params.float_activation_min;
  const float output_activation_max = params.float_activation_max;

  int i = 0;
#ifdef USE_NEON
  const auto activation_min = vdupq_n_f32(output_activation_min);
  const auto activation_max = vdupq_n_f32(output_activation_max);
  for (; i <= size - 16; i += 16) {
    auto a10 = vld1q_f32(input1_data + i);
    auto a11 = vld1q_f32(input1_data + i + 4);
    auto a12 = vld1q_f32(input1_data + i + 8);
    auto a13 = vld1q_f32(input1_data + i + 12);
    auto a20 = vld1q_f32(input2_data + i);
    auto a21 = vld1q_f32(input2_data + i + 4);
    auto a22 = vld1q_f32(input2_data + i + 8);
    auto a23 = vld1q_f32(input2_data + i + 12);
    auto x0 = vmulq_f32(a10, a20);
    auto x1 = vmulq_f32(a11, a21);
    auto x2 = vmulq_f32(a12, a22);
    auto x3 = vmulq_f32(a13, a23);

    x0 = vmaxq_f32(activation_min, x0);
    x1 = vmaxq_f32(activation_min, x1);
    x2 = vmaxq_f32(activation_min, x2);
    x3 = vmaxq_f32(activation_min, x3);
    x0 = vminq_f32(activation_max, x0);
    x1 = vminq_f32(activation_max, x1);
    x2 = vminq_f32(activation_max, x2);
    x3 = vminq_f32(activation_max, x3);

    vst1q_f32(output_data + i, x0);
    vst1q_f32(output_data + i + 4, x1);
    vst1q_f32(output_data + i + 8, x2);
    vst1q_f32(output_data + i + 12, x3);
  }
  for (; i <= size - 4; i += 4) {
    auto a1 = vld1q_f32(input1_data + i);
    auto a2 = vld1q_f32(input2_data + i);
    auto x = vmulq_f32(a1, a2);

    x = vmaxq_f32(activation_min, x);
    x = vminq_f32(activation_max, x);

    vst1q_f32(output_data + i, x);
  }
#endif  // NEON

  for (; i < size; i++) {
    auto x = input1_data[i] * input2_data[i];
    output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
                                                  output_activation_max);
  }
}

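// Note on the float MulElementwise above: on NEON builds it processes 16
// floats per iteration (four 4-lane registers), then 4 at a time, and the
// scalar loop handles the remaining size % 4 elements; on non-NEON builds
// only the scalar loop runs. The quantized kernels below follow the same
// main-loop / tail structure.
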
inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const float* input1_data,
                const RuntimeShape& input2_shape, const float* input2_data,
                const RuntimeShape& output_shape, float* output_data) {
  ruy::profiler::ScopeLabel label("Mul");

  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);
  MulElementwise(flat_size, params, input1_data, input2_data, output_data);
}

inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const int32* input1_data,
                const RuntimeShape& input2_shape, const int32* input2_data,
                const RuntimeShape& output_shape, int32* output_data) {
  ruy::profiler::ScopeLabel label("Mul/int32/activation");

  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  for (int i = 0; i < flat_size; ++i) {
    output_data[i] = ActivationFunctionWithMinMax(
        input1_data[i] * input2_data[i], output_activation_min,
        output_activation_max);
  }
}

inline void MulNoActivation(const ArithmeticParams& params,
                            const RuntimeShape& input1_shape,
                            const int32* input1_data,
                            const RuntimeShape& input2_shape,
                            const int32* input2_data,
                            const RuntimeShape& output_shape,
                            int32* output_data) {
  ruy::profiler::ScopeLabel label("Mul/int32");

  auto input1_map = MapAsVector(input1_data, input1_shape);
  auto input2_map = MapAsVector(input2_data, input2_shape);
  auto output_map = MapAsVector(output_data, output_shape);
  if (input1_shape == input2_shape) {
    output_map.array() = input1_map.array() * input2_map.array();
  } else if (input2_shape.FlatSize() == 1) {
    auto scalar = input2_data[0];
    output_map.array() = input1_map.array() * scalar;
  } else if (input1_shape.FlatSize() == 1) {
    auto scalar = input1_data[0];
    output_map.array() = scalar * input2_map.array();
  } else {
    reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data,
                                      input2_shape, input2_data, output_shape,
                                      output_data);
  }
}

inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const int16* input1_data,
                const RuntimeShape& input2_shape, const int16* input2_data,
                const RuntimeShape& output_shape, int16* output_data) {
  ruy::profiler::ScopeLabel label("Mul/Int16/NoActivation");
  // This is a copy of the reference implementation. We do not currently have a
  // properly optimized version.

  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;

    F0 unclamped_result =
        F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
    output_data[i] = unclamped_result.raw();
  }
}

inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const int16* input1_data,
                const RuntimeShape& input2_shape, const int16* input2_data,
                const RuntimeShape& output_shape, uint8* output_data) {
  ruy::profiler::ScopeLabel label("Mul/Int16Uint8");
  // This is a copy of the reference implementation. We do not currently have a
  // properly optimized version.
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  const int32 output_offset = params.output_offset;
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);

  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;

    F0 unclamped_result =
        F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
    int16 rescaled_result =
        gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
    int16 clamped_result =
        std::min<int16>(output_activation_max - output_offset, rescaled_result);
    clamped_result =
        std::max<int16>(output_activation_min - output_offset, clamped_result);
    output_data[i] = output_offset + clamped_result;
  }
}

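// Worked example for the int16 -> uint8 Mul above (values chosen for
// illustration): two Q0.15 inputs of 0.5 (raw 16384) multiply to roughly
// 0.25 (raw 8192); RoundingDivideByPOT(8192, 8) gives 32, which is clamped
// relative to output_offset and then recentered, so with output_offset = 128
// and full-range activation limits the stored uint8 value is about 160.
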
// Element-wise mul that can often be used for inner loop of broadcast Mul as
// well as the non-broadcast Mul.
inline void MulElementwise(int size, const ArithmeticParams& params,
                           const uint8* input1_data, const uint8* input2_data,
                           uint8* output_data) {
  int i = 0;
  TFLITE_DCHECK_GT(params.input1_offset, -256);
  TFLITE_DCHECK_LT(params.input1_offset, 256);
  TFLITE_DCHECK_GT(params.input2_offset, -256);
  TFLITE_DCHECK_LT(params.input2_offset, 256);
  TFLITE_DCHECK_GT(params.output_offset, -256);
  TFLITE_DCHECK_LT(params.output_offset, 256);
#ifdef USE_NEON
  const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
  const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
  const auto output_offset_vector = vdupq_n_s16(params.output_offset);
  const auto output_activation_min_vector =
      vdup_n_u8(params.quantized_activation_min);
  const auto output_activation_max_vector =
      vdup_n_u8(params.quantized_activation_max);
  const int left_shift = std::max(0, params.output_shift);
  const int right_shift = std::max(0, -params.output_shift);
  const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
  for (; i <= size - 8; i += 8) {
    // We load / store 8 at a time, multiplying as two sets of 4 int32s.
    const auto input1_val_original = vld1_u8(input1_data + i);
    const auto input2_val_original = vld1_u8(input2_data + i);
    const auto input1_val_s16 =
        vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
    const auto input2_val_s16 =
        vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
    const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
    const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);

    const auto input1_val_low = vget_low_s16(input1_val);
    const auto input1_val_high = vget_high_s16(input1_val);
    const auto input2_val_low = vget_low_s16(input2_val);
    const auto input2_val_high = vget_high_s16(input2_val);

    auto p1 = vmull_s16(input2_val_low, input1_val_low);
    auto p2 = vmull_s16(input2_val_high, input1_val_high);

    p1 = vshlq_s32(p1, left_shift_vec);
    p2 = vshlq_s32(p2, left_shift_vec);
    p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
    p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
    using gemmlowp::RoundingDivideByPOT;
    p1 = RoundingDivideByPOT(p1, right_shift);
    p2 = RoundingDivideByPOT(p2, right_shift);

    const auto p1_narrowed = vqmovn_s32(p1);
    const auto p2_narrowed = vqmovn_s32(p2);
    const auto p =
        vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
    const auto clamped =
        vmax_u8(output_activation_min_vector,
                vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
    vst1_u8(output_data + i, clamped);
  }
#endif  // NEON

  for (; i < size; ++i) {
    const int32 input1_val = params.input1_offset + input1_data[i];
    const int32 input2_val = params.input2_offset + input2_data[i];
    const int32 unclamped_result =
        params.output_offset +
        MultiplyByQuantizedMultiplier(input1_val * input2_val,
                                      params.output_multiplier,
                                      params.output_shift);
    const int32 clamped_output =
        std::min(params.quantized_activation_max,
                 std::max(params.quantized_activation_min, unclamped_result));
    output_data[i] = static_cast<uint8>(clamped_output);
  }
}

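// Note on the uint8 MulElementwise above: it follows the usual affine
// quantization scheme. The offset-adjusted inputs (q + input_offset) are
// multiplied in 32 bits, the product is rescaled by output_multiplier /
// output_shift (a fixed-point approximation of scale1 * scale2 / scale_out),
// and output_offset is added back before clamping. The DCHECKs simply assert
// that the offsets are negated 8-bit zero points, hence within (-256, 256).
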
// Broadcast mul that can often be used for inner loop of broadcast Mul.
inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
                               const uint8 broadcast_value,
                               const uint8* input2_data, uint8* output_data) {
  const int16 input1_val = params.input1_offset + broadcast_value;

  int i = 0;
  TFLITE_DCHECK_GT(params.input1_offset, -256);
  TFLITE_DCHECK_LT(params.input1_offset, 256);
  TFLITE_DCHECK_GT(params.input2_offset, -256);
  TFLITE_DCHECK_LT(params.input2_offset, 256);
  TFLITE_DCHECK_GT(params.output_offset, -256);
  TFLITE_DCHECK_LT(params.output_offset, 256);
#ifdef USE_NEON
  const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
  const auto output_offset_vector = vdupq_n_s16(params.output_offset);
  const auto output_activation_min_vector =
      vdup_n_u8(params.quantized_activation_min);
  const auto output_activation_max_vector =
      vdup_n_u8(params.quantized_activation_max);
  const int left_shift = std::max(0, params.output_shift);
  const int right_shift = std::max(0, -params.output_shift);
  const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
  for (; i <= size - 8; i += 8) {
    // We load / store 8 at a time, multiplying as two sets of 4 int32s.
    const auto input2_val_original = vld1_u8(input2_data + i);
    const auto input2_val_s16 =
        vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
    const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);

    const auto input2_val_low = vget_low_s16(input2_val);
    const auto input2_val_high = vget_high_s16(input2_val);

    auto p1 = vmull_n_s16(input2_val_low, input1_val);
    auto p2 = vmull_n_s16(input2_val_high, input1_val);

    p1 = vshlq_s32(p1, left_shift_vec);
    p2 = vshlq_s32(p2, left_shift_vec);
    p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
    p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
    using gemmlowp::RoundingDivideByPOT;
    p1 = RoundingDivideByPOT(p1, right_shift);
    p2 = RoundingDivideByPOT(p2, right_shift);

    const auto p1_narrowed = vmovn_s32(p1);
    const auto p2_narrowed = vmovn_s32(p2);
    const auto p =
        vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
    const auto clamped =
        vmax_u8(output_activation_min_vector,
                vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
    vst1_u8(output_data + i, clamped);
  }
#endif  // NEON

  for (; i < size; ++i) {
    const int32 input2_val = params.input2_offset + input2_data[i];
    const int32 unclamped_result =
        params.output_offset +
        MultiplyByQuantizedMultiplier(input1_val * input2_val,
                                      params.output_multiplier,
                                      params.output_shift);
    const int32 clamped_output =
        std::min(params.quantized_activation_max,
                 std::max(params.quantized_activation_min, unclamped_result));
    output_data[i] = static_cast<uint8>(clamped_output);
  }
}

// Broadcast mul that can often be used for inner loop of broadcast Mul.
// This function will handle scalar_value (LHS) * vector_values (RHS).
// Since it's a float function, the input quantization params do not matter
// here.
inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
                               const float broadcast_value,
                               const float* input2_data, float* output_data) {
  int i = 0;
#ifdef USE_NEON
  const float32x4_t output_activation_min_vector =
      vdupq_n_f32(params.float_activation_min);
  const float32x4_t output_activation_max_vector =
      vdupq_n_f32(params.float_activation_max);
  const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
  for (; i <= size - 4; i += 4) {
    const float32x4_t input2_val_original = vld1q_f32(input2_data + i);

    const float32x4_t output =
        vmulq_f32(input2_val_original, broadcast_value_dup);

    const float32x4_t clamped =
        vmaxq_f32(output_activation_min_vector,
                  vminq_f32(output_activation_max_vector, output));
    vst1q_f32(output_data + i, clamped);
  }
#endif  // NEON

  for (; i < size; ++i) {
    float x = broadcast_value * input2_data[i];
    output_data[i] = ActivationFunctionWithMinMax(
        x, params.float_activation_min, params.float_activation_max);
  }
}

inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const uint8* input1_data,
                const RuntimeShape& input2_shape, const uint8* input2_data,
                const RuntimeShape& output_shape, uint8* output_data) {
  TFLITE_DCHECK_LE(params.quantized_activation_min,
                   params.quantized_activation_max);
  ruy::profiler::ScopeLabel label("Mul/8bit");
  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);

  MulElementwise(flat_size, params, input1_data, input2_data, output_data);
}

template <typename T>
inline void BroadcastMulDispatch(
    const ArithmeticParams& params, const RuntimeShape& input1_shape,
    const T* input1_data, const RuntimeShape& input2_shape,
    const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
  if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
    return BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
                              input2_data, output_shape, output_data);
  }

  BinaryBroadcastFiveFold(
      params, input1_shape, input1_data, input2_shape, input2_data,
      output_shape, output_data,
      static_cast<void (*)(int, const ArithmeticParams&, const T*, const T*,
                           T*)>(MulElementwise),
      static_cast<void (*)(int, const ArithmeticParams&, T, const T*, T*)>(
          MulSimpleBroadcast));
}

inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
                                 const RuntimeShape& unswitched_input1_shape,
                                 const uint8* unswitched_input1_data,
                                 const RuntimeShape& unswitched_input2_shape,
                                 const uint8* unswitched_input2_data,
                                 const RuntimeShape& output_shape,
                                 uint8* output_data) {
  BroadcastMulDispatch(unswitched_params, unswitched_input1_shape,
                       unswitched_input1_data, unswitched_input2_shape,
                       unswitched_input2_data, output_shape, output_data);
}

inline void BroadcastMulFivefold(const ArithmeticParams& params,
                                 const RuntimeShape& unswitched_input1_shape,
                                 const float* unswitched_input1_data,
                                 const RuntimeShape& unswitched_input2_shape,
                                 const float* unswitched_input2_data,
                                 const RuntimeShape& output_shape,
                                 float* output_data) {
  BroadcastMulDispatch(params, unswitched_input1_shape, unswitched_input1_data,
                       unswitched_input2_shape, unswitched_input2_data,
                       output_shape, output_data);
}

// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
// TODO(benoitjacob): BroadcastDiv is intentionally duplicated from
// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
// reference_ops.h.
template <typename T, int N = 5>
void BroadcastDivSlow(const ArithmeticParams& params,
                      const RuntimeShape& unextended_input1_shape,
                      const T* input1_data,
                      const RuntimeShape& unextended_input2_shape,
                      const T* input2_data,
                      const RuntimeShape& unextended_output_shape,
                      T* output_data) {
  ruy::profiler::ScopeLabel label("BroadcastDivSlow");
  T output_activation_min;
  T output_activation_max;
  GetActivationParams(params, &output_activation_min, &output_activation_max);

  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);

  NdArrayDesc<N> desc1;
  NdArrayDesc<N> desc2;
  NdArrayDesc<N> output_desc;
  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
                                      unextended_input2_shape, &desc1, &desc2);
  CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
                 &output_desc);

  // In Tensorflow, the dimensions are canonically named (batch_number, row,
  // col, channel), with extents (batches, height, width, depth), with the
  // trailing dimension changing most rapidly (channels has the smallest stride,
  // typically 1 element).
  //
  // In generated C code, we store arrays with the dimensions reversed. The
  // first dimension has smallest stride.
  //
  // We name our variables by their Tensorflow convention, but generate C code
  // nesting loops such that the innermost loop has the smallest stride for the
  // best cache behavior.
  auto div_func = [&](int indexes[N]) {
    output_data[SubscriptToIndex(output_desc, indexes)] =
        ActivationFunctionWithMinMax(
            input1_data[SubscriptToIndex(desc1, indexes)] /
                input2_data[SubscriptToIndex(desc2, indexes)],
            output_activation_min, output_activation_max);
  };
  NDOpsHelper<N>(output_desc, div_func);
}

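// Note on the generic BroadcastDivSlow above: NDOpsHelper<N> visits every
// index tuple of the extended N-dimensional output shape and invokes div_func
// on it; SubscriptToIndex maps the tuple through each operand's NdArrayDesc,
// which uses zero strides along broadcast dimensions, so a broadcast operand
// is re-read in place rather than materialized.
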
// BroadcastDiv is intentionally duplicated from reference_ops.h.
// For more details see the comment above the generic version of
// BroadcastDivSlow.
template <int N = 5>
inline void BroadcastDivSlow(const ArithmeticParams& params,
                             const RuntimeShape& unextended_input1_shape,
                             const uint8* input1_data,
                             const RuntimeShape& unextended_input2_shape,
                             const uint8* input2_data,
                             const RuntimeShape& unextended_output_shape,
                             uint8* output_data) {
  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);

  NdArrayDesc<N> desc1;
  NdArrayDesc<N> desc2;
  NdArrayDesc<N> output_desc;
  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
                                      unextended_input2_shape, &desc1, &desc2);
  CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
                 &output_desc);

  TFLITE_DCHECK_GT(params.input1_offset, -256);
  TFLITE_DCHECK_LT(params.input1_offset, 256);
  TFLITE_DCHECK_GT(params.input2_offset, -256);
  TFLITE_DCHECK_LT(params.input2_offset, 256);
  TFLITE_DCHECK_GT(params.output_offset, -256);
  TFLITE_DCHECK_LT(params.output_offset, 256);

  auto div_func = [&](int indexes[N]) {
    const int32 input1_val =
        params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)];
    const int32 input2_val =
        params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)];
    TFLITE_DCHECK_NE(input2_val, 0);
    int recip_shift;
    const int32 input2_inv =
        (input2_val > 0) ? GetReciprocal(input2_val, 31, &recip_shift)
                         : -GetReciprocal(-input2_val, 31, &recip_shift);
    const int headroom = CountLeadingSignBits(input1_val);
    const int32 unscaled_quotient = MultiplyByQuantizedMultiplierGreaterThanOne(
        input1_val, input2_inv, headroom);
    const int total_shift = params.output_shift - recip_shift - headroom;
    const int32 unclamped_result =
        params.output_offset +
        MultiplyByQuantizedMultiplierSmallerThanOneExp(
            unscaled_quotient, params.output_multiplier, total_shift);
    const int32 clamped_output =
        std::min(params.quantized_activation_max,
                 std::max(params.quantized_activation_min, unclamped_result));
    output_data[SubscriptToIndex(output_desc, indexes)] =
        static_cast<uint8>(clamped_output);
  };
  NDOpsHelper<N>(output_desc, div_func);
}

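// Note on the uint8 BroadcastDivSlow above: rather than dividing per element,
// it computes a fixed-point reciprocal of the offset-adjusted divisor with
// GetReciprocal, multiplies the dividend by that reciprocal, and folds the
// reciprocal's exponent, the dividend's headroom and params.output_shift into
// one final shift before the output offset and clamping are applied.
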
template <typename T>
inline void SubWithActivation(
    const ArithmeticParams& params, const RuntimeShape& input1_shape,
    const T* input1_data, const RuntimeShape& input2_shape,
    const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("SubWithActivation_optimized");
  TFLITE_DCHECK_EQ(input1_shape.FlatSize(), input2_shape.FlatSize());
  auto input1_map = MapAsVector(input1_data, input1_shape);
  auto input2_map = MapAsVector(input2_data, input2_shape);
  auto output_map = MapAsVector(output_data, output_shape);
  T activation_min, activation_max;
  GetActivationParams(params, &activation_min, &activation_max);
  output_map.array() = (input1_map.array() - input2_map.array())
                           .cwiseMin(activation_max)
                           .cwiseMax(activation_min);
}

inline void SubNonBroadcast(const ArithmeticParams& params,
                            const RuntimeShape& input1_shape,
                            const float* input1_data,
                            const RuntimeShape& input2_shape,
                            const float* input2_data,
                            const RuntimeShape& output_shape,
                            float* output_data) {
  ruy::profiler::ScopeLabel label("SubNonBroadcast");
  SubWithActivation<float>(params, input1_shape, input1_data, input2_shape,
                           input2_data, output_shape, output_data);
}

template <typename T>
void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
         const T* input1_data, const RuntimeShape& input2_shape,
         const T* input2_data, const RuntimeShape& output_shape,
         T* output_data) {
  ruy::profiler::ScopeLabel label("Sub");

  auto input1_map = MapAsVector(input1_data, input1_shape);
  auto input2_map = MapAsVector(input2_data, input2_shape);
  auto output_map = MapAsVector(output_data, output_shape);
  if (input1_shape == input2_shape) {
    output_map.array() = input1_map.array() - input2_map.array();
  } else if (input1_shape.FlatSize() == 1) {
    auto scalar = input1_data[0];
    output_map.array() = scalar - input2_map.array();
  } else if (input2_shape.FlatSize() == 1) {
    auto scalar = input2_data[0];
    output_map.array() = input1_map.array() - scalar;
  } else {
    BroadcastSubSlow(params, input1_shape, input1_data, input2_shape,
                     input2_data, output_shape, output_data);
  }
}

inline void LstmCell(
    const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
    const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
    const float* prev_activ_data, const RuntimeShape& weights_shape,
    const float* weights_data, const RuntimeShape& unextended_bias_shape,
    const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
    const float* prev_state_data,
    const RuntimeShape& unextended_output_state_shape, float* output_state_data,
    const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
    const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
    const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data,
    CpuBackendContext* cpu_backend_context) {
  ruy::profiler::ScopeLabel label("LstmCell");
  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
  const RuntimeShape input_shape =
      RuntimeShape::ExtendedShape(4, unextended_input_shape);
  const RuntimeShape prev_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
  const RuntimeShape bias_shape =
      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
  const RuntimeShape prev_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
  const RuntimeShape output_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
  const RuntimeShape output_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
  const RuntimeShape concat_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
  const RuntimeShape activ_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);

  const int weights_dim_count = weights_shape.DimensionsCount();
  MatchingDim(  // batches
      input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
      output_state_shape, 0, output_activ_shape, 0);
  MatchingDim(  // height
      input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
      output_state_shape, 1, output_activ_shape, 1);
  MatchingDim(  // width
      input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
      output_state_shape, 2, output_activ_shape, 2);
  const int input_depth = input_shape.Dims(3);
  const int prev_activ_depth = prev_activ_shape.Dims(3);
  const int total_input_depth = prev_activ_depth + input_depth;
  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
                   total_input_depth);
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
  const int intern_activ_depth =
      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
                   intern_activ_depth * total_input_depth);
  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
  const int output_depth =
      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
                  3, output_activ_shape, 3);
  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);

  // Concatenate prev_activ and input data together
  std::vector<float const*> concat_input_arrays_data;
  std::vector<RuntimeShape const*> concat_input_arrays_shapes;
  concat_input_arrays_data.push_back(input_data);
  concat_input_arrays_data.push_back(prev_activ_data);
  concat_input_arrays_shapes.push_back(&input_shape);
  concat_input_arrays_shapes.push_back(&prev_activ_shape);
  tflite::ConcatenationParams concat_params;
  concat_params.axis = 3;
  concat_params.inputs_count = concat_input_arrays_data.size();
  Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
                &(concat_input_arrays_data[0]), concat_temp_shape,
                concat_temp_data);

  // Fully connected
  tflite::FullyConnectedParams fc_params;
  fc_params.float_activation_min = std::numeric_limits<float>::lowest();
  fc_params.float_activation_max = std::numeric_limits<float>::max();
  fc_params.lhs_cacheable = false;
  fc_params.rhs_cacheable = false;
  FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
                 weights_data, bias_shape, bias_data, activ_temp_shape,
                 activ_temp_data, cpu_backend_context);

  // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
  // operations.
  ArrayMap<float> activ_temp_map =
      MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape);
  auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
                                            activ_temp_map.cols());
  auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
                                           activ_temp_map.cols());
  auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth,
                                             activ_temp_map.cols());
  auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
                                             activ_temp_map.cols());
  ArrayMap<const float> prev_state_map =
      MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape);
  ArrayMap<float> output_state_map =
      MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape);
  ArrayMap<float> output_activ_map =
      MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape);

  // Combined memory state and final output calculation
  ruy::profiler::ScopeLabel label2("MemoryStateAndFinalOutput");
  output_state_map =
      input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
          new_input_sm.tanh() +
      forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
          prev_state_map;
  output_activ_map =
      output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
      output_state_map.tanh();
}

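// Note on the float LstmCell above: the single FullyConnected produces all
// four gate pre-activations stacked along the depth axis, and the two Eigen
// expressions at the end implement the standard LSTM update
//   c_t = sigmoid(i) * tanh(g) + sigmoid(f) * c_{t-1}
//   h_t = sigmoid(o) * tanh(c_t)
// where i, g, f, o are the input, new-input (modulation), forget and output
// gate blocks of activ_temp, c is output_state and h is output_activ.
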
2896 template <int StateIntegerBits>
LstmCell(const LstmCellParams & params,const RuntimeShape & unextended_input_shape,const uint8 * input_data_uint8,const RuntimeShape & unextended_prev_activ_shape,const uint8 * prev_activ_data_uint8,const RuntimeShape & weights_shape,const uint8 * weights_data_uint8,const RuntimeShape & unextended_bias_shape,const int32 * bias_data_int32,const RuntimeShape & unextended_prev_state_shape,const int16 * prev_state_data_int16,const RuntimeShape & unextended_output_state_shape,int16 * output_state_data_int16,const RuntimeShape & unextended_output_activ_shape,uint8 * output_activ_data_uint8,const RuntimeShape & unextended_concat_temp_shape,uint8 * concat_temp_data_uint8,const RuntimeShape & unextended_activ_temp_shape,int16 * activ_temp_data_int16,CpuBackendContext * cpu_backend_context)2897 inline void LstmCell(
2898     const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
2899     const uint8* input_data_uint8,
2900     const RuntimeShape& unextended_prev_activ_shape,
2901     const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
2902     const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
2903     const int32* bias_data_int32,
2904     const RuntimeShape& unextended_prev_state_shape,
2905     const int16* prev_state_data_int16,
2906     const RuntimeShape& unextended_output_state_shape,
2907     int16* output_state_data_int16,
2908     const RuntimeShape& unextended_output_activ_shape,
2909     uint8* output_activ_data_uint8,
2910     const RuntimeShape& unextended_concat_temp_shape,
2911     uint8* concat_temp_data_uint8,
2912     const RuntimeShape& unextended_activ_temp_shape,
2913     int16* activ_temp_data_int16, CpuBackendContext* cpu_backend_context) {
2914   ruy::profiler::ScopeLabel label(
2915       "LstmCell/quantized (8bit external, 16bit internal)");
2916   int32 weights_zero_point = params.weights_zero_point;
2917   int32 accum_multiplier = params.accum_multiplier;
2918   int accum_shift = params.accum_shift;
2919   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2920   TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
2921   TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
2922   TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
2923   TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
2924   TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
2925   TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
2926   TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
2927   const RuntimeShape input_shape =
2928       RuntimeShape::ExtendedShape(4, unextended_input_shape);
2929   const RuntimeShape prev_activ_shape =
2930       RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
2931   const RuntimeShape bias_shape =
2932       RuntimeShape::ExtendedShape(4, unextended_bias_shape);
2933   const RuntimeShape prev_state_shape =
2934       RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
2935   const RuntimeShape output_state_shape =
2936       RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
2937   const RuntimeShape output_activ_shape =
2938       RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
2939   const RuntimeShape concat_temp_shape =
2940       RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
2941   const RuntimeShape activ_temp_shape =
2942       RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
2943   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
2944 
2945   // Gather dimensions information, and perform consistency checks.
2946   const int weights_dim_count = weights_shape.DimensionsCount();
2947   const int outer_size = MatchingFlatSizeSkipDim(
2948       input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
2949       output_activ_shape);
2950   const int input_depth = input_shape.Dims(3);
2951   const int prev_activ_depth = prev_activ_shape.Dims(3);
2952   const int total_input_depth = prev_activ_depth + input_depth;
2953   TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
2954                    total_input_depth);
2955   const int intern_activ_depth =
2956       MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
2957   TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
2958                    intern_activ_depth * total_input_depth);
2959   TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
2960   TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
2961   const int output_depth =
2962       MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
2963                   3, output_activ_shape, 3);
2964   TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
2965   const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
2966   const int fc_output_depth =
2967       MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
2968   const int fc_accum_depth = total_input_depth;
2969   TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
2970 
2971   // Depth-concatenate prev_activ and input data together.
2972   uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
2973                                               prev_activ_data_uint8};
2974   const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
2975                                                        &prev_activ_shape};
2976   tflite::ConcatenationParams concat_params;
2977   concat_params.axis = 3;
2978   concat_params.inputs_count = 2;
2979   Concatenation(concat_params, concat_input_arrays_shapes,
2980                 concat_input_arrays_data, concat_temp_shape,
2981                 concat_temp_data_uint8);
2982 
2983   // Implementation of the fully connected node inside the LSTM cell.
2984   // The operands are 8-bit integers, the accumulators are internally 32bit
2985   // integers, and the output is 16-bit fixed-point with 3 integer bits so
2986   // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
2987   // is explained in the function comment above.
2988   cpu_backend_gemm::MatrixParams<uint8> lhs_params;
2989   lhs_params.rows = fc_output_depth;
2990   lhs_params.cols = fc_accum_depth;
2991   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
2992   lhs_params.zero_point = weights_zero_point;
2993   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
2994   rhs_params.rows = fc_accum_depth;
2995   rhs_params.cols = fc_batches;
2996   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
2997   rhs_params.zero_point = 128;
2998   cpu_backend_gemm::MatrixParams<int16> dst_params;
2999   dst_params.rows = fc_output_depth;
3000   dst_params.cols = fc_batches;
3001   dst_params.order = cpu_backend_gemm::Order::kColMajor;
3002   dst_params.zero_point = 0;
3003   cpu_backend_gemm::GemmParams<int32, int16> gemm_params;
3004   gemm_params.bias = bias_data_int32;
3005   gemm_params.multiplier_fixedpoint = accum_multiplier;
3006   gemm_params.multiplier_exponent = accum_shift;
3007   cpu_backend_gemm::Gemm(
3008       lhs_params, weights_data_uint8, rhs_params, concat_temp_data_uint8,
3009       dst_params, activ_temp_data_int16, gemm_params, cpu_backend_context);
3010 
3011   // Rest of the LSTM cell: tanh and logistic math functions, and some adds
3012   // and muls, all done in 16-bit fixed-point.
3013   const int16* input_gate_input_ptr = activ_temp_data_int16;
3014   const int16* input_modulation_gate_input_ptr =
3015       activ_temp_data_int16 + output_depth;
3016   const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth;
3017   const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth;
3018   const int16* prev_state_ptr = prev_state_data_int16;
3019   int16* output_state_data_ptr = output_state_data_int16;
3020   uint8* output_activ_data_ptr = output_activ_data_uint8;
3021 
3022   for (int b = 0; b < outer_size; ++b) {
3023     int c = 0;
3024 #ifdef GEMMLOWP_NEON
3025     for (; c <= output_depth - 8; c += 8) {
3026       // Define the fixed-point data types that we will use here. All use
3027       // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3028       // They only differ by the number of integral vs. fractional bits,
3029       // determining the range of values that they can represent.
3030       //
3031       // F0 uses 0 integer bits, range [-1, 1].
3032       // This is the return type of math functions such as tanh, logistic,
3033       // whose range is in [-1, 1].
3034       using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
3035       // F3 uses 3 integer bits, range [-8, 8].
3036       // This is the range of the previous fully-connected node's output,
3037       // which is our input here.
3038       using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
3039       // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3040       // 2^StateIntegerBits]. It's used to represent the internal state, whose
3041       // number of integer bits is currently dictated by the model. See comment
3042       // on the StateIntegerBits template parameter above.
3043       using FS = gemmlowp::FixedPoint<int16x8_t, StateIntegerBits>;
3044       // Implementation of input gate, using fixed-point logistic function.
3045       F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr));
3046       input_gate_input_ptr += 8;
3047       F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3048       // Implementation of input modulation gate, using fixed-point tanh
3049       // function.
3050       F3 input_modulation_gate_input =
3051           F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr));
3052       input_modulation_gate_input_ptr += 8;
3053       F0 input_modulation_gate_output =
3054           gemmlowp::tanh(input_modulation_gate_input);
3055       // Implementation of forget gate, using fixed-point logistic function.
3056       F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr));
3057       forget_gate_input_ptr += 8;
3058       F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3059       // Implementation of output gate, using fixed-point logistic function.
3060       F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr));
3061       output_gate_input_ptr += 8;
3062       F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3063       // Implementation of internal multiplication nodes, still in fixed-point.
3064       F0 input_times_input_modulation =
3065           input_gate_output * input_modulation_gate_output;
3066       FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr));
3067       prev_state_ptr += 8;
3068       FS prev_state_times_forget_state = forget_gate_output * prev_state;
3069       // Implementation of internal addition node, saturating.
3070       FS new_state = gemmlowp::SaturatingAdd(
3071           gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3072           prev_state_times_forget_state);
3073       // Implementation of last internal Tanh node, still in fixed-point.
3074       // Since a Tanh fixed-point implementation is specialized for a given
3075       // number or integer bits, and each specialization can have a substantial
3076       // code size, and we already used above a Tanh on an input with 3 integer
3077       // bits, and per the table in the above function comment there is no
3078       // significant accuracy to be lost by clamping to [-8, +8] for a
3079       // 3-integer-bits representation, let us just do that. This helps people
3080       // porting this to targets where code footprint must be minimized.
3081       F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3082       F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3083       // Store the new internal state back to memory, as 16-bit integers.
3084       // Note: here we store the original value with StateIntegerBits, not
3085       // the rescaled 3-integer-bits value fed to tanh.
3086       vst1q_s16(output_state_data_ptr, new_state.raw());
3087       output_state_data_ptr += 8;
3088       // Down-scale the output activations to 8-bit integers, saturating,
3089       // and store back to memory.
3090       int16x8_t rescaled_output_activ =
3091           gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3092       int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ);
3093       uint8x8_t uint8_output_activ =
3094           vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ));
3095       vst1_u8(output_activ_data_ptr, uint8_output_activ);
3096       output_activ_data_ptr += 8;
3097     }
3098 #endif
3099     for (; c < output_depth; ++c) {
3100       // Define the fixed-point data types that we will use here. All use
3101       // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3102       // They only differ by the number of integral vs. fractional bits,
3103       // determining the range of values that they can represent.
3104       //
3105       // F0 uses 0 integer bits, range [-1, 1].
3106       // This is the return type of math functions such as tanh, logistic,
3107       // whose range is in [-1, 1].
3108       using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
3109       // F3 uses 3 integer bits, range [-8, 8].
3110       // This is the range of the previous fully-connected node's output,
3111       // which is our input here.
3112       using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
3113       // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3114       // 2^StateIntegerBits]. It's used to represent the internal state, whose
3115       // number of integer bits is currently dictated by the model. See comment
3116       // on the StateIntegerBits template parameter above.
3117       using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
3118       // Implementation of input gate, using fixed-point logistic function.
3119       F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++);
3120       F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3121       // Implementation of input modulation gate, using fixed-point tanh
3122       // function.
3123       F3 input_modulation_gate_input =
3124           F3::FromRaw(*input_modulation_gate_input_ptr++);
3125       F0 input_modulation_gate_output =
3126           gemmlowp::tanh(input_modulation_gate_input);
3127       // Implementation of forget gate, using fixed-point logistic function.
3128       F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++);
3129       F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3130       // Implementation of output gate, using fixed-point logistic function.
3131       F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++);
3132       F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3133       // Implementation of internal multiplication nodes, still in fixed-point.
3134       F0 input_times_input_modulation =
3135           input_gate_output * input_modulation_gate_output;
3136       FS prev_state = FS::FromRaw(*prev_state_ptr++);
3137       FS prev_state_times_forget_state = forget_gate_output * prev_state;
3138       // Implementation of internal addition node, saturating.
3139       FS new_state = gemmlowp::SaturatingAdd(
3140           gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3141           prev_state_times_forget_state);
3142       // Implementation of last internal Tanh node, still in fixed-point.
3143       // Since a Tanh fixed-point implementation is specialized for a given
3144       // number of integer bits, and each specialization can have a substantial
3145       // code size, and we already used above a Tanh on an input with 3 integer
3146       // bits, and per the table in the above function comment there is no
3147       // significant accuracy to be lost by clamping to [-8, +8] for a
3148       // 3-integer-bits representation, let us just do that. This helps people
3149       // porting this to targets where code footprint must be minimized.
3150       F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3151       F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3152       // Store the new internal state back to memory, as 16-bit integers.
3153       // Note: here we store the original value with StateIntegerBits, not
3154       // the rescaled 3-integer-bits value fed to tanh.
3155       *output_state_data_ptr++ = new_state.raw();
3156       // Down-scale the output activations to 8-bit integers, saturating,
3157       // and store back to memory.
3158       int16 rescaled_output_activ =
3159           gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3160       int16 clamped_output_activ =
3161           std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
3162       *output_activ_data_ptr++ = 128 + clamped_output_activ;
3163     }
3164     input_gate_input_ptr += 3 * output_depth;
3165     input_modulation_gate_input_ptr += 3 * output_depth;
3166     forget_gate_input_ptr += 3 * output_depth;
3167     output_gate_input_ptr += 3 * output_depth;
3168   }
3169 }
3170 
3171 inline int NodeOffset(int b, int h, int w, int height, int width) {
3172   return (b * height + h) * width + w;
3173 }
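
// For example, with height = 4 and width = 5, NodeOffset(/*b=*/1, /*h=*/2,
// /*w=*/3, 4, 5) = (1 * 4 + 2) * 5 + 3 = 33, i.e. the column index of that
// (b, h, w) position in the matrices produced by MapAsMatrixWithLastDimAsRows
// in the pooling ops below.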
3174 
3175 inline bool AveragePool(const PoolParams& params,
3176                         const RuntimeShape& input_shape,
3177                         const float* input_data,
3178                         const RuntimeShape& output_shape, float* output_data) {
3179   ruy::profiler::ScopeLabel label("AveragePool");
3180   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3181   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3182   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3183   const int input_height = input_shape.Dims(1);
3184   const int input_width = input_shape.Dims(2);
3185   const int output_height = output_shape.Dims(1);
3186   const int output_width = output_shape.Dims(2);
3187   const int stride_height = params.stride_height;
3188   const int stride_width = params.stride_width;
3189 
3190   if (stride_height == 0) return false;
3191   if (stride_width == 0) return false;
3192 
3193   // TODO(benoitjacob) make this a proper reference impl without Eigen!
3194   const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3195   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3196   // TODO(benoitjacob) get rid of the dynamic memory allocation here!
3197   Eigen::VectorXf out_count(out_mat.cols());
3198   out_count.setZero();
3199   // Prefill the output to 0.
3200   out_mat.setZero();
3201   for (int b = 0; b < batches; ++b) {
3202     for (int h = 0; h < input_height; ++h) {
3203       for (int w = 0; w < input_width; ++w) {
3204         // (h_start, h_end) * (w_start, w_end) is the range that the input
3205         // vector projects to.
3206         int hpad = h + params.padding_values.height;
3207         int wpad = w + params.padding_values.width;
3208         int h_start = (hpad < params.filter_height)
3209                           ? 0
3210                           : (hpad - params.filter_height) / stride_height + 1;
3211         int h_end = std::min(hpad / stride_height + 1, output_height);
3212         int w_start = (wpad < params.filter_width)
3213                           ? 0
3214                           : (wpad - params.filter_width) / stride_width + 1;
3215         int w_end = std::min(wpad / stride_width + 1, output_width);
3216         // compute elementwise sum
3217         for (int ph = h_start; ph < h_end; ++ph) {
3218           for (int pw = w_start; pw < w_end; ++pw) {
3219             int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
3220             out_mat.col(out_offset) +=
3221                 in_mat.col(NodeOffset(b, h, w, input_height, input_width));
3222             out_count(out_offset)++;
3223           }
3224         }
3225       }
3226     }
3227   }
3228   // Divide the output by the actual number of elements being averaged over
3229   TFLITE_DCHECK_GT(out_count.minCoeff(), 0);
3230   out_mat.array().rowwise() /= out_count.transpose().array();
3231 
3232   const int flat_size = output_shape.FlatSize();
3233   for (int i = 0; i < flat_size; ++i) {
3234     output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3235                                                   params.float_activation_min,
3236                                                   params.float_activation_max);
3237   }
3238 
3239   return true;
3240 }
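
// Editor-added usage sketch (not part of the original TFLite sources; the
// shapes and parameter values are assumptions chosen for illustration):
// average-pools a 1x4x4x1 float tensor with a 2x2 filter and stride 2 into a
// 1x2x2x1 output, so each output value is the mean of one 2x2 input patch.
inline bool ExampleAveragePoolFloatSketch() {
  PoolParams params;
  params.stride_height = 2;
  params.stride_width = 2;
  params.filter_height = 2;
  params.filter_width = 2;
  params.padding_values.height = 0;
  params.padding_values.width = 0;
  params.float_activation_min = std::numeric_limits<float>::lowest();
  params.float_activation_max = std::numeric_limits<float>::max();
  const RuntimeShape input_shape({1, 4, 4, 1});
  const RuntimeShape output_shape({1, 2, 2, 1});
  const float input_data[16] = {1, 2,  3,  4,  5,  6,  7,  8,
                                9, 10, 11, 12, 13, 14, 15, 16};
  float output_data[4];  // Expected: {3.5f, 5.5f, 11.5f, 13.5f}.
  return AveragePool(params, input_shape, input_data, output_shape,
                     output_data);
}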
3241 
3242 inline bool AveragePool(const PoolParams& params,
3243                         const RuntimeShape& input_shape,
3244                         const uint8* input_data,
3245                         const RuntimeShape& output_shape, uint8* output_data) {
3246   ruy::profiler::ScopeLabel label("AveragePool/8bit");
3247 
3248   // Here, and in other pooling ops, in order to maintain locality of reference,
3249   // to minimize some recalculations, and to load into NEON vector registers, we
3250   // use an inner loop down the depth. Since depths can be large and hence we
3251   // would need arbitrarily large temporary storage, we divide the work up into
3252   // depth tranches just within the batch loop.
3253   static constexpr int kPoolingAccTrancheSize = 256;
3254 
3255   TFLITE_DCHECK_LE(params.quantized_activation_min,
3256                    params.quantized_activation_max);
3257   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3258   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3259   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3260   const int depth = MatchingDim(input_shape, 3, output_shape, 3);
3261   const int input_height = input_shape.Dims(1);
3262   const int input_width = input_shape.Dims(2);
3263   const int output_height = output_shape.Dims(1);
3264   const int output_width = output_shape.Dims(2);
3265   const int stride_height = params.stride_height;
3266   const int stride_width = params.stride_width;
3267 
3268   uint32 acc[kPoolingAccTrancheSize];
3269   for (int batch = 0; batch < batches; ++batch) {
3270     // We proceed through the depth in tranches (see comment above). The
3271     // depth_base is the depth at the beginning of the tranche. The
3272     // tranche_depth is the depth dimension of the tranche.
3273     for (int depth_base = 0; depth_base < depth;
3274          depth_base += kPoolingAccTrancheSize) {
3275       const int tranche_depth =
3276           std::min(depth - depth_base, kPoolingAccTrancheSize);
3277       for (int out_y = 0; out_y < output_height; ++out_y) {
3278         for (int out_x = 0; out_x < output_width; ++out_x) {
3279           const int in_x_origin =
3280               (out_x * stride_width) - params.padding_values.width;
3281           const int in_y_origin =
3282               (out_y * stride_height) - params.padding_values.height;
3283           const int filter_x_start = std::max(0, -in_x_origin);
3284           const int filter_x_end =
3285               std::min(params.filter_width, input_width - in_x_origin);
3286           const int filter_y_start = std::max(0, -in_y_origin);
3287           const int filter_y_end =
3288               std::min(params.filter_height, input_height - in_y_origin);
3289           const int filter_count =
3290               (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
3291           if (filter_count == 0) return false;
3292           memset(acc, 0, tranche_depth * sizeof(acc[0]));
3293           const uint8* input_ptr =
3294               input_data + depth_base +
3295               depth * (in_x_origin +
3296                        input_width * (in_y_origin + input_height * batch));
3297           for (int fy = filter_y_start; fy < filter_y_end; fy++) {
3298             const uint8* input_row_ptr =
3299                 input_ptr + depth * (fy * input_width + filter_x_start);
3300             for (int fx = filter_x_start; fx < filter_x_end; fx++) {
3301               const uint8* input_channel_ptr = input_row_ptr;
3302               int channel = 0;
3303 #ifdef USE_NEON
3304               for (; channel <= tranche_depth - 16; channel += 16) {
3305                 uint16x4_t acc_reg[4];
3306                 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
3307                 input_channel_ptr += 16;
3308                 acc_reg[0] = vget_low_u16(vmovl_u8(vget_low_u8(input_reg)));
3309                 acc_reg[1] = vget_high_u16(vmovl_u8(vget_low_u8(input_reg)));
3310                 acc_reg[2] = vget_low_u16(vmovl_u8(vget_high_u8(input_reg)));
3311                 acc_reg[3] = vget_high_u16(vmovl_u8(vget_high_u8(input_reg)));
3312                 for (int i = 0; i < 4; i++) {
3313                   vst1q_u32(
3314                       acc + channel + 4 * i,
3315                       vaddw_u16(vld1q_u32(acc + channel + 4 * i), acc_reg[i]));
3316                 }
3317               }
3318               for (; channel <= tranche_depth - 8; channel += 8) {
3319                 uint16x4_t acc_reg[2];
3320                 uint16x8_t input_reg = vmovl_u8(vld1_u8(input_channel_ptr));
3321                 input_channel_ptr += 8;
3322                 acc_reg[0] = vget_low_u16(input_reg);
3323                 acc_reg[1] = vget_high_u16(input_reg);
3324                 for (int i = 0; i < 2; i++) {
3325                   vst1q_u32(
3326                       acc + channel + 4 * i,
3327                       vaddw_u16(vld1q_u32(acc + channel + 4 * i), acc_reg[i]));
3328                 }
3329               }
3330 #endif
3331               for (; channel < tranche_depth; ++channel) {
3332                 acc[channel] += *input_channel_ptr++;
3333               }
3334               input_row_ptr += depth;
3335             }
3336           }
3337           uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
3338                                                    out_x, depth_base);
3339           int channel = 0;
3340 #ifdef USE_NEON
3341 #define AVGPOOL_DIVIDING_BY(FILTER_COUNT)                               \
3342   if (filter_count == FILTER_COUNT) {                                   \
3343     for (; channel <= tranche_depth - 8; channel += 8) {                \
3344       uint16 buf[8];                                                    \
3345       for (int i = 0; i < 8; i++) {                                     \
3346         buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT;  \
3347       }                                                                 \
3348       uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));                      \
3349       buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); \
3350       buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); \
3351       vst1_u8(output_ptr + channel, buf8);                              \
3352     }                                                                   \
3353   }
3354           AVGPOOL_DIVIDING_BY(9)
3355           AVGPOOL_DIVIDING_BY(15)
3356 #undef AVGPOOL_DIVIDING_BY
3357           for (; channel <= tranche_depth - 8; channel += 8) {
3358             uint16 buf[8];
3359             for (int i = 0; i < 8; i++) {
3360               buf[i] = (acc[channel + i] + filter_count / 2) / filter_count;
3361             }
3362             uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));
3363             buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max));
3364             buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min));
3365             vst1_u8(output_ptr + channel, buf8);
3366           }
3367 #endif
3368           for (; channel < tranche_depth; ++channel) {
3369             uint16 a = (acc[channel] + filter_count / 2) / filter_count;
3370             a = std::max<uint16>(a, params.quantized_activation_min);
3371             a = std::min<uint16>(a, params.quantized_activation_max);
3372             output_ptr[channel] = static_cast<uint8>(a);
3373           }
3374         }
3375       }
3376     }
3377   }
3378   return true;
3379 }
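
// Editor's note on the scalar tail above: (acc[channel] + filter_count / 2) /
// filter_count is integer division rounded to nearest. For example, with
// filter_count = 9, acc = 40 (average 4.44) gives (40 + 4) / 9 = 4, while
// acc = 41 (average 4.56) gives (41 + 4) / 9 = 5.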
3380 
3381 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
3382                     const float* input_data, const RuntimeShape& output_shape,
3383                     float* output_data) {
3384   ruy::profiler::ScopeLabel label("MaxPool");
3385   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3386   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3387   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3388   const int input_height = input_shape.Dims(1);
3389   const int input_width = input_shape.Dims(2);
3390   const int output_height = output_shape.Dims(1);
3391   const int output_width = output_shape.Dims(2);
3392   const int stride_height = params.stride_height;
3393   const int stride_width = params.stride_width;
3394 
3395   const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3396   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3397   // Prefill the output to minimum representable float value
3398   out_mat.setConstant(std::numeric_limits<float>::lowest());
3399   for (int b = 0; b < batches; ++b) {
3400     for (int h = 0; h < input_height; ++h) {
3401       for (int w = 0; w < input_width; ++w) {
3402         // (h_start, h_end) * (w_start, w_end) is the range that the input
3403         // vector projects to.
3404         int hpad = h + params.padding_values.height;
3405         int wpad = w + params.padding_values.width;
3406         int h_start = (hpad < params.filter_height)
3407                           ? 0
3408                           : (hpad - params.filter_height) / stride_height + 1;
3409         int h_end = std::min(hpad / stride_height + 1, output_height);
3410         int w_start = (wpad < params.filter_width)
3411                           ? 0
3412                           : (wpad - params.filter_width) / stride_width + 1;
3413         int w_end = std::min(wpad / stride_width + 1, output_width);
3414         // compute elementwise max
3415         for (int ph = h_start; ph < h_end; ++ph) {
3416           for (int pw = w_start; pw < w_end; ++pw) {
3417             int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
3418             out_mat.col(out_offset) =
3419                 out_mat.col(out_offset)
3420                     .cwiseMax(in_mat.col(
3421                         NodeOffset(b, h, w, input_height, input_width)));
3422           }
3423         }
3424       }
3425     }
3426   }
3427   const int flat_size = output_shape.FlatSize();
3428   for (int i = 0; i < flat_size; ++i) {
3429     output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3430                                                   params.float_activation_min,
3431                                                   params.float_activation_max);
3432   }
3433 }
3434 
3435 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
3436                     const uint8* input_data, const RuntimeShape& output_shape,
3437                     uint8* output_data) {
3438   ruy::profiler::ScopeLabel label("MaxPool/8bit");
3439 
3440   // Here, and in other pooling ops, in order to maintain locality of reference,
3441   // to minimize some recalculations, and to load into NEON vector registers, we
3442   // use an inner loop down the depth. Since depths can be large and hence we
3443   // would need arbitrarily large temporary storage, we divide the work up into
3444   // depth tranches just within the batch loop.
3445   static constexpr int kPoolingAccTrancheSize = 256;
3446 
3447   TFLITE_DCHECK_LE(params.quantized_activation_min,
3448                    params.quantized_activation_max);
3449   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3450   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3451   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3452   const int depth = MatchingDim(input_shape, 3, output_shape, 3);
3453   const int input_height = input_shape.Dims(1);
3454   const int input_width = input_shape.Dims(2);
3455   const int output_height = output_shape.Dims(1);
3456   const int output_width = output_shape.Dims(2);
3457   const int stride_height = params.stride_height;
3458   const int stride_width = params.stride_width;
3459 
3460   uint8 acc[kPoolingAccTrancheSize];
3461   for (int batch = 0; batch < batches; ++batch) {
3462     // We proceed through the depth in tranches (see comment above). The
3463     // depth_base is the depth at the beginning of the tranche. The
3464     // tranche_depth is the depth dimension of the tranche.
3465     for (int depth_base = 0; depth_base < depth;
3466          depth_base += kPoolingAccTrancheSize) {
3467       const int tranche_depth =
3468           std::min(depth - depth_base, kPoolingAccTrancheSize);
3469       for (int out_y = 0; out_y < output_height; ++out_y) {
3470         for (int out_x = 0; out_x < output_width; ++out_x) {
3471           const int in_x_origin =
3472               (out_x * stride_width) - params.padding_values.width;
3473           const int in_y_origin =
3474               (out_y * stride_height) - params.padding_values.height;
3475           const int filter_x_start = std::max(0, -in_x_origin);
3476           const int filter_x_end =
3477               std::min(params.filter_width, input_width - in_x_origin);
3478           const int filter_y_start = std::max(0, -in_y_origin);
3479           const int filter_y_end =
3480               std::min(params.filter_height, input_height - in_y_origin);
3481           memset(acc, 0, tranche_depth * sizeof(acc[0]));
3482           const uint8* input_ptr =
3483               input_data + depth_base +
3484               depth * (in_x_origin +
3485                        input_width * (in_y_origin + input_height * batch));
3486           for (int fy = filter_y_start; fy < filter_y_end; fy++) {
3487             const uint8* input_row_ptr =
3488                 input_ptr + depth * (fy * input_width + filter_x_start);
3489             for (int fx = filter_x_start; fx < filter_x_end; fx++) {
3490               const uint8* input_channel_ptr = input_row_ptr;
3491               int channel = 0;
3492 #ifdef USE_NEON
3493               for (; channel <= tranche_depth - 16; channel += 16) {
3494                 uint8x16_t acc_reg = vld1q_u8(acc + channel);
3495                 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
3496                 input_channel_ptr += 16;
3497                 acc_reg = vmaxq_u8(acc_reg, input_reg);
3498                 vst1q_u8(acc + channel, acc_reg);
3499               }
3500 
3501               for (; channel <= tranche_depth - 8; channel += 8) {
3502                 uint8x8_t acc_reg = vld1_u8(acc + channel);
3503                 uint8x8_t input_reg = vld1_u8(input_channel_ptr);
3504                 input_channel_ptr += 8;
3505                 acc_reg = vmax_u8(acc_reg, input_reg);
3506                 vst1_u8(acc + channel, acc_reg);
3507               }
3508 #endif
3509               for (; channel < tranche_depth; ++channel) {
3510                 acc[channel] = std::max(acc[channel], *input_channel_ptr++);
3511               }
3512               input_row_ptr += depth;
3513             }
3514           }
3515           uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
3516                                                    out_x, depth_base);
3517           int channel = 0;
3518 #ifdef USE_NEON
3519           for (; channel <= tranche_depth - 16; channel += 16) {
3520             uint8x16_t a = vld1q_u8(acc + channel);
3521             a = vminq_u8(a, vdupq_n_u8(params.quantized_activation_max));
3522             a = vmaxq_u8(a, vdupq_n_u8(params.quantized_activation_min));
3523             vst1q_u8(output_ptr + channel, a);
3524           }
3525           for (; channel <= tranche_depth - 8; channel += 8) {
3526             uint8x8_t a = vld1_u8(acc + channel);
3527             a = vmin_u8(a, vdup_n_u8(params.quantized_activation_max));
3528             a = vmax_u8(a, vdup_n_u8(params.quantized_activation_min));
3529             vst1_u8(output_ptr + channel, a);
3530           }
3531 #endif
3532           for (; channel < tranche_depth; ++channel) {
3533             uint8 a = acc[channel];
3534             a = std::max<uint8>(a, params.quantized_activation_min);
3535             a = std::min<uint8>(a, params.quantized_activation_max);
3536             output_ptr[channel] = static_cast<uint8>(a);
3537           }
3538         }
3539       }
3540     }
3541   }
3542 }
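
// Editor-added usage sketch (not part of the original TFLite sources; shapes
// and quantization limits are assumptions): max-pools a 1x2x2x1 uint8 tensor
// with a 2x2 filter into a single value and clamps it to the quantized
// activation range.
inline void ExampleMaxPoolUint8Sketch() {
  PoolParams params;
  params.stride_height = 2;
  params.stride_width = 2;
  params.filter_height = 2;
  params.filter_width = 2;
  params.padding_values.height = 0;
  params.padding_values.width = 0;
  params.quantized_activation_min = 0;
  params.quantized_activation_max = 255;
  const RuntimeShape input_shape({1, 2, 2, 1});
  const RuntimeShape output_shape({1, 1, 1, 1});
  const uint8 input_data[4] = {7, 42, 13, 25};
  uint8 output_data[1];  // Expected: {42}.
  MaxPool(params, input_shape, input_data, output_shape, output_data);
}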
3543 
3544 inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
3545                    const float* input_data, const RuntimeShape& output_shape,
3546                    float* output_data) {
3547   ruy::profiler::ScopeLabel label("L2Pool");
3548   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3549   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3550   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3551   const int input_height = input_shape.Dims(1);
3552   const int input_width = input_shape.Dims(2);
3553   const int output_height = output_shape.Dims(1);
3554   const int output_width = output_shape.Dims(2);
3555   const int stride_height = params.stride_height;
3556   const int stride_width = params.stride_width;
3557   // Actually carry out L2 Pool. Code is written in forward mode: we go through
3558   // the input values once, and write to all the pooled regions each value maps to.
3559   const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3560   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3561   Eigen::VectorXf in_square(in_mat.rows());
3562   Eigen::VectorXf out_count(out_mat.cols());
3563   out_count.setZero();
3564   // Prefill the output to 0.
3565   out_mat.setZero();
3566   for (int b = 0; b < batches; ++b) {
3567     for (int h = 0; h < input_height; ++h) {
3568       for (int w = 0; w < input_width; ++w) {
3569         // (h_start, h_end) * (w_start, w_end) is the range that the input
3570         // vector projects to.
3571         const int hpad = h + params.padding_values.height;
3572         const int wpad = w + params.padding_values.width;
3573         const int h_start =
3574             (hpad < params.filter_height)
3575                 ? 0
3576                 : (hpad - params.filter_height) / stride_height + 1;
3577         const int h_end = std::min(hpad / stride_height + 1, output_height);
3578         const int w_start =
3579             (wpad < params.filter_width)
3580                 ? 0
3581                 : (wpad - params.filter_width) / stride_width + 1;
3582         const int w_end = std::min(wpad / stride_width + 1, output_width);
3583         // pre-compute square
3584         const int in_offset = w + input_width * (h + input_height * b);
3585         in_square =
3586             in_mat.col(in_offset).array() * in_mat.col(in_offset).array();
3587         // compute elementwise sum of squares
3588         for (int ph = h_start; ph < h_end; ++ph) {
3589           for (int pw = w_start; pw < w_end; ++pw) {
3590             const int out_offset = pw + output_width * (ph + output_height * b);
3591             out_mat.col(out_offset) += in_square;
3592             out_count(out_offset)++;
3593           }
3594         }
3595       }
3596     }
3597   }
3598 
3599   out_count = out_count.array().inverse();
3600   out_mat =
3601       (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
3602 
3603   const int flat_size = output_shape.FlatSize();
3604   for (int i = 0; i < flat_size; ++i) {
3605     output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3606                                                   params.float_activation_min,
3607                                                   params.float_activation_max);
3608   }
3609 }
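
// Editor's note: each L2Pool output is sqrt(sum of squares / count) over its
// pooling window, e.g. a window containing {3.0f, 4.0f} yields
// sqrt((9 + 16) / 2) = sqrt(12.5) ~= 3.54.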
3610 
3611 inline void LocalResponseNormalization(
3612     const tflite::LocalResponseNormalizationParams& op_params,
3613     const RuntimeShape& input_shape, const float* input_data,
3614     const RuntimeShape& output_shape, float* output_data) {
3615   ruy::profiler::ScopeLabel label("LocalResponseNormalization");
3616   MatchingFlatSize(input_shape, output_shape);
3617 
3618   const auto data_in = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3619   auto data_out = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3620 
3621   // Carry out local response normalization, vector by vector.
3622   // Since the data are stored column-major, a row-wise operation would
3623   // probably not be memory-efficient anyway, so we do an explicit for loop
3624   // over the columns.
3625   const int double_range = op_params.range * 2;
3626   Eigen::VectorXf padded_square(data_in.rows() + double_range);
3627   padded_square.setZero();
3628   const float bias = op_params.bias;
3629   for (int r = 0; r < data_in.cols(); ++r) {
3630     // Do local response normalization for data_in(:, r)
3631     // first, compute the squares and store them in a buffer for repeated use
3632     padded_square.block(op_params.range, 0, data_in.rows(), 1) =
3633         data_in.col(r).cwiseProduct(data_in.col(r)) * op_params.alpha;
3634     // Then, compute the scale and write it to data_out
3635     float accumulated_scale = 0;
3636     for (int i = 0; i < double_range; ++i) {
3637       accumulated_scale += padded_square(i);
3638     }
3639     for (int i = 0; i < data_in.rows(); ++i) {
3640       accumulated_scale += padded_square(i + double_range);
3641       data_out(i, r) = bias + accumulated_scale;
3642       accumulated_scale -= padded_square(i);
3643     }
3644   }
3645 
3646   // In a few cases, the pow computation could benefit from speedups.
3647   if (op_params.beta == 1) {
3648     data_out.array() = data_in.array() * data_out.array().inverse();
3649   } else if (op_params.beta == 0.5f) {
3650     data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
3651   } else {
3652     data_out.array() = data_in.array() * data_out.array().pow(-op_params.beta);
3653   }
3654 }
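
// Editor-added usage sketch (not part of the original TFLite sources; the
// parameter values are assumptions): per the loops above, each output is
// in[i] * (bias + alpha * sum(in[j]^2 for j in [i - range, i + range]))^-beta.
inline void ExampleLocalResponseNormalizationSketch() {
  tflite::LocalResponseNormalizationParams op_params;
  op_params.range = 2;
  op_params.bias = 1.0;
  op_params.alpha = 1.0;
  op_params.beta = 0.5;
  const RuntimeShape shape({1, 1, 1, 8});
  const float input_data[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float output_data[8];
  LocalResponseNormalization(op_params, shape, input_data, shape, output_data);
}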
3655 
3656 inline void SoftmaxImpl(const SoftmaxParams& params,
3657                         const RuntimeShape& input_shape,
3658                         const float* input_data,
3659                         const RuntimeShape& output_shape, float* output_data,
3660                         int start_batch, int end_batch) {
3661   ruy::profiler::ScopeLabel label("Softmax/Impl");
3662   MatchingFlatSize(input_shape, output_shape);
3663 
3664   const int logit_size = input_shape.Dims(input_shape.DimensionsCount() - 1);
3665   const MatrixMap<const float> in_mat(input_data + logit_size * start_batch,
3666                                       logit_size, end_batch - start_batch);
3667   MatrixMap<float> out_mat(output_data + logit_size * start_batch, logit_size,
3668                            end_batch - start_batch);
3669   // Compute the exponential first, removing the max coefficient for numerical
3670   // stability.
3671   out_mat =
3672       (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
3673   // We are separating out the exp function so that exp can be vectorized.
3674   out_mat = out_mat.array().exp();
3675   // Normalize to get the activations.
3676   Eigen::Array<float, 1, Eigen::Dynamic> scale =
3677       out_mat.array().colwise().sum().inverse();
3678   out_mat.array().rowwise() *= scale;
3679 }
3680 
3681 struct SoftmaxWorkerTask : cpu_backend_threadpool::Task {
3682   SoftmaxWorkerTask(const SoftmaxParams& params,
3683                     const RuntimeShape& input_shape, const float* input_data,
3684                     const RuntimeShape& output_shape, float* output_data,
3685                     int start_batch, int end_batch)
3686       : params(params),
3687         input_shape(input_shape),
3688         input_data(input_data),
3689         output_shape(output_shape),
3690         output_data(output_data),
3691         start_batch(start_batch),
3692         end_batch(end_batch) {}
3693   void Run() override {
3694     SoftmaxImpl(params, input_shape, input_data, output_shape, output_data,
3695                 start_batch, end_batch);
3696   }
3697 
3698  private:
3699   const tflite::SoftmaxParams& params;
3700   const RuntimeShape& input_shape;
3701   const float* input_data;
3702   const RuntimeShape& output_shape;
3703   float* output_data;
3704   int start_batch;
3705   int end_batch;
3706 };
3707 
3708 inline void Softmax(const SoftmaxParams& params,
3709                     const RuntimeShape& input_shape, const float* input_data,
3710                     const RuntimeShape& output_shape, float* output_data,
3711                     CpuBackendContext* cpu_backend_context = nullptr) {
3712   ruy::profiler::ScopeLabel label("Softmax");
3713 
3714   // We picture the softmax input as a 2-D matrix whose last dim is the logit
3715   // dim; the remaining dims are collapsed into the batch dim of that matrix.
3716   const int batch_size =
3717       FlatSizeSkipDim(input_shape, input_shape.DimensionsCount() - 1);
3718   constexpr int kMinBatchPerThread = 8;
3719   int thread_count = batch_size / kMinBatchPerThread;
3720   thread_count = thread_count > 0 ? thread_count : 1;
3721   const int capped_thread_count =
3722       cpu_backend_context == nullptr
3723           ? 1
3724           : std::min(thread_count, cpu_backend_context->max_num_threads());
3725   if (capped_thread_count == 1) {
3726     SoftmaxImpl(params, input_shape, input_data, output_shape, output_data, 0,
3727                 batch_size);
3728   } else {
3729     std::vector<SoftmaxWorkerTask> tasks;
3730     // TODO(b/131746020) don't create new heap allocations every time.
3731     // At least we make it a single heap allocation by using reserve().
3732     tasks.reserve(capped_thread_count);
3733     int batch_start = 0;
3734     for (int i = 0; i < capped_thread_count; ++i) {
3735       // Try to distribute the tasks as evenly as possible.
3736       int batch_end =
3737           batch_start + (batch_size - batch_start) / (capped_thread_count - i);
3738       tasks.emplace_back(params, input_shape, input_data, output_shape,
3739                          output_data, batch_start, batch_end);
3740       batch_start = batch_end;
3741     }
3742     cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
3743                                     cpu_backend_context);
3744   }
3745 }
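
// Editor-added usage sketch (not part of the original TFLite sources; shapes
// and beta are assumptions): softmax over 8 rows of 10 logits each. Passing a
// CpuBackendContext enables the batch-parallel path above; the default nullptr
// runs single-threaded.
inline void ExampleFloatSoftmaxSketch() {
  SoftmaxParams params;
  params.beta = 1.0;
  const RuntimeShape shape({8, 10});
  float input_data[80] = {};  // All-zero logits.
  float output_data[80];
  Softmax(params, shape, input_data, shape, output_data);
  // With uniform logits, every output in a row is 1.0f / 10 = 0.1f.
}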
3746 
3747 template <typename T>
3748 inline int32_t QuantizeSoftmaxOutput(float prob_rescaled, int32_t zero_point) {
3749   const int32_t prob_rnd = static_cast<int32_t>(std::round(prob_rescaled));
3750   return prob_rnd + zero_point;
3751 }
3752 
3753 #if !__aarch64__
3754 // On ARM64, rounding is faster than add + truncation, so skip this there.
3755 template <>
3756 inline int32_t QuantizeSoftmaxOutput<uint8_t>(float prob_rescaled,
3757                                               int32_t zero_point) {
3758   return static_cast<int32_t>(prob_rescaled + 0.5f);
3759 }
3760 #endif
3761 
3762 inline void PopulateSoftmaxLookupTable(SoftmaxParams* data, float input_scale,
3763                                        float beta) {
3764   const float scale = -input_scale * beta;
3765   const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3766   for (int32_t val = 0; val <= max_uint8; ++val) {
3767     data->table[max_uint8 - val] = expf(scale * val);
3768   }
3769 }
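
// Editor's note: with scale = -input_scale * beta, the loop above fills
// data->table[255 - val] = exp(-input_scale * beta * val). The LUT Softmax
// below then reads table_offset = &table[255 - max_val], so that
// table_offset[x] = exp(input_scale * beta * (x - max_val)), i.e. the exp of
// the dequantized, max-subtracted logit, which keeps the sum from overflowing.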
3770 
3771 template <typename In, typename Out>
3772 inline void Softmax(const SoftmaxParams& params,
3773                     const RuntimeShape& input_shape, const In* input_data,
3774                     const RuntimeShape& output_shape, Out* output_data) {
3775   const int trailing_dim = input_shape.DimensionsCount() - 1;
3776   const int excluding_last_dim =
3777       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3778   const int last_dim =
3779       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3780 
3781   const int32_t clamp_max = std::numeric_limits<Out>::max();
3782   const int32_t clamp_min = std::numeric_limits<Out>::min();
3783   for (int i = 0; i < excluding_last_dim; ++i) {
3784     int32_t max_val = std::numeric_limits<In>::min();
3785     // Find max quantized value.
3786     for (int j = 0; j < last_dim; ++j) {
3787       max_val = std::max(max_val, static_cast<int32_t>(input_data[j]));
3788     }
3789 
3790     float sum_exp = 0.0f;
3791     const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3792     const float* table_offset = &params.table[max_uint8 - max_val];
3793     // Calculate normalizer sum(exp(x)).
3794     for (int j = 0; j < last_dim; ++j) {
3795       sum_exp += table_offset[input_data[j]];
3796     }
3797 
3798     const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
3799     // Normalize and quantize probabilities.
3800     for (int j = 0; j < last_dim; ++j) {
3801       const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp;
3802       const int32_t prob_quantized =
3803           QuantizeSoftmaxOutput<Out>(prob_rescaled, params.zero_point);
3804       output_data[j] = static_cast<Out>(
3805           std::max(std::min(clamp_max, prob_quantized), clamp_min));
3806     }
3807     input_data += last_dim;
3808     output_data += last_dim;
3809   }
3810 }
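
// Editor-added usage sketch (not part of the original TFLite sources; the
// quantization parameters are assumptions typical of a uint8 softmax output
// with scale 1/256 and zero point 0): populate the float lookup table, then
// run the LUT-based quantized softmax above.
inline void ExampleUint8SoftmaxLutSketch() {
  SoftmaxParams params;
  params.zero_point = 0;
  params.scale = 1.0f / 256.0f;  // Output scale.
  PopulateSoftmaxLookupTable(&params, /*input_scale=*/0.1f, /*beta=*/1.0f);
  const RuntimeShape shape({1, 1, 1, 4});
  const uint8_t input_data[4] = {100, 110, 120, 130};
  uint8_t output_data[4];
  Softmax(params, shape, input_data, shape, output_data);
}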
3811 
3812 // Here's the softmax LUT optimization strategy:
3813 // For softmax, we can do a mathematically equivalent transformation:
3814 //
3815 // softmax(x) = e^x / sum(e^x, 0...n)  ===> equals to
3816 // softmax(x) = e^(x - CONST) / sum(e^(x - CONST), 0...n)
3817 //
3818 // For quantization, `x` in our case is (input_q - input_zp) * input_s
3819 // For uint8 case (int8 can be handled similarly), the range is [0, 255]
3820 //
3821 // so if we let
3822 // CONST = (255 - input_zp) * input_s
3823 // then we will have:
3824 // softmax(x) = e^((input_q - 255) * input_s) --------- (1)
3825 //         /
3826 // sum(e^((input_q - 255) * input_s), 0...n)   -------- (2)
3827 //
3828 // the good thing about (1) is it's within the range of (0, 1), so we can
3829 // approximate its result with uint16.
3830 //  (1) = uint16_out * 1 / 2^16.
3831 //
3832 // so (1) is lookup_uint8_table(input_q) * 1 / 2^16.
3833 // then (2) is essentially the following:
3834 // sum(lookup_uint8_table(input_q), 0...n) / 2^16.
3835 //
3836 // since (output_q - output_zp) * output_s = softmax(x)
3837 // output_q = lookup_uint8_table(input_q)
3838 //            /
3839 // (sum(lookup_uint8_table(input_q), 0...n) * output_s)
3840 //             +
3841 //   output_zp
3842 //
3843 // We can actually further improve the performance by using uint8 instead of
3844 // uint16. But then we may lose some accuracy, so we need to pay attention
3845 // to that.
3846 inline void PopulateSoftmaxUInt8LookupTable(SoftmaxParams* data,
3847                                             float input_scale, float beta) {
3848   const float scale = input_scale * beta;
3849   const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3850   const int32_t max_uint16 = std::numeric_limits<uint16_t>::max();
3851 
3852   for (int32_t val = 0; val <= max_uint8; ++val) {
3853     float input_to_exp = scale * (val - max_uint8);
3854     int32_t temp = static_cast<int>(expf(input_to_exp) * max_uint16 + 0.5);
3855     temp = std::min(max_uint16, temp);
3856     uint8_t part1 = temp >> 8;
3857     uint8_t part2 = temp & 0xff;
3858     data->uint8_table1[val] = static_cast<uint8_t>(part1);
3859     data->uint8_table2[val] = static_cast<uint8_t>(part2);
3860   }
3861 }
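
// Editor's note: the two uint8 tables together hold one uint16 value per
// input byte. For example, temp = 40000 is stored as part1 = 40000 >> 8 = 156
// and part2 = 40000 & 0xff = 64, and is reassembled later as
// (part1 << 8) + part2 = 39936 + 64 = 40000.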
3862 
3863 inline int FindMaxValue(int size, const uint8_t* input_data, uint8_t offset) {
3864   int32_t max_val = std::numeric_limits<uint8_t>::min();
3865   int j = 0;
3866 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3867   uint8x16_t max_val_dup = vdupq_n_u8(max_val);
3868   uint8x16_t offset_dup = vdupq_n_u8(offset);
3869   for (; j <= size - 16; j += 16) {
3870     uint8x16_t input_value = vld1q_u8(input_data + j);
3871     input_value = veorq_u8(input_value, offset_dup);
3872     max_val_dup = vmaxq_u8(input_value, max_val_dup);
3873   }
3874   max_val = std::max(max_val, static_cast<int32>(vmaxvq_u8(max_val_dup)));
3875 #endif
3876 
3877   for (; j < size; ++j) {
3878     max_val = std::max(max_val, static_cast<int32_t>(input_data[j] ^ offset));
3879   }
3880   return max_val;
3881 }
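
// Editor's note: XOR-ing with `offset` maps int8 values that were
// reinterpreted as uint8 onto an order-preserving 0..255 scale: int8 -128
// (bits 0x80) ^ 0x80 = 0 and int8 127 (bits 0x7F) ^ 0x80 = 255, so the
// unsigned max over the XOR-ed bytes finds the largest logit for both input
// types.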
3882 
3883 #ifdef USE_NEON
3884 // Value_to_store layout:
3885 // [high_high, high_low, low_high, low_low].
3886 inline void StoreValue(int32x4x4_t value_to_store, int8_t* output) {
3887   const int16x8_t result_1 = vcombine_s16(vqmovn_s32(value_to_store.val[1]),
3888                                           vqmovn_s32(value_to_store.val[0]));
3889   const int16x8_t result_2 = vcombine_s16(vqmovn_s32(value_to_store.val[3]),
3890                                           vqmovn_s32(value_to_store.val[2]));
3891   const int8x16_t result =
3892       vcombine_s8(vqmovn_s16(result_2), vqmovn_s16(result_1));
3893   vst1q_s8(output, result);
3894 }
3895 
3896 // Value_to_store layout:
3897 // [high_high, high_low, low_high, low_low].
3898 inline void StoreValue(int32x4x4_t value_to_store, uint8_t* output) {
3899   const uint16x8_t result_1 =
3900       vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[1])),
3901                    vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[0])));
3902   const uint16x8_t result_2 =
3903       vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[3])),
3904                    vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[2])));
3905   const uint8x16_t result =
3906       vcombine_u8(vqmovn_u16(result_2), vqmovn_u16(result_1));
3907   vst1q_u8(output, result);
3908 }
3909 
3910 #endif
3911 
3912 template <typename In, typename Out>
3913 inline void SoftmaxInt8LUT(const SoftmaxParams& params,
3914                            const RuntimeShape& input_shape,
3915                            const In* input_data,
3916                            const RuntimeShape& output_shape, Out* output_data) {
3917   const int trailing_dim = input_shape.DimensionsCount() - 1;
3918   const int excluding_last_dim =
3919       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3920   const int last_dim =
3921       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3922 
3923   const int32_t clamp_max = std::numeric_limits<Out>::max();
3924   const int32_t clamp_min = std::numeric_limits<Out>::min();
3925 
3926   // Offset is used to interpret the input data "correctly".
3927   // If the input is uint8, the data is unchanged.
3928   // If the input is int8, it is reinterpreted as uint8, and the offset
3929   // 0x80 is applied via XOR, e.g. int8 127 (bit pattern 0x7F) becomes
3930   // 255 in uint8.
3931   uint8_t offset = 0;
3932   if (std::is_same<In, int8>::value) {
3933     offset = 0x80;
3934   }
3935 
3936   const uint8_t* input_data_uint = reinterpret_cast<const uint8_t*>(input_data);
3937 
3938 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3939   // This code uses ARM64-only instructions.
3940   // TODO(b/143709993): Port to ARMv7
3941 
3942   // Load the tables into registers. (4*4 128-bit registers)
3943   uint8x16x4_t table1[4];
3944   table1[0] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 0);
3945   table1[1] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 1);
3946   table1[2] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 2);
3947   table1[3] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 3);
3948 
3949   uint8x16x4_t table2[4];
3950   table2[0] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 0);
3951   table2[1] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 1);
3952   table2[2] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 2);
3953   table2[3] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 3);
3954 #endif
3955 
3956   for (int i = 0; i < excluding_last_dim; ++i) {
3957     // Find max quantized value.
3958     int32_t max_val = FindMaxValue(last_dim, input_data_uint, offset);
3959 
3960     int32 sum_exp = 0;
3961     const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3962     const uint8_t table_offset = max_uint8 - max_val;
3963 
3964     // Calculate normalizer sum(exp(x)).
3965     int sum_j = 0;
3966 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3967     uint8x16_t table_offset_dup = vdupq_n_u8(table_offset);
3968     uint8x16_t offset_dup = vdupq_n_u8(offset);
3969     uint32x4_t sum_4 = vdupq_n_u32(0);
3970     const int multiplier_shift = 8;
3971     for (; sum_j <= last_dim - 16; sum_j += 16) {
3972       uint8x16_t input_value = vld1q_u8(input_data_uint + sum_j);
3973       input_value = veorq_u8(input_value, offset_dup);
3974       input_value = vaddq_u8(input_value, table_offset_dup);
3975 
3976       const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
3977       const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
3978 
3979       uint16x8_t exp_value1 =
3980           vshll_n_u8(vget_high_u8(output1), multiplier_shift);
3981       uint16x8_t exp_value2 =
3982           vshll_n_u8(vget_low_u8(output1), multiplier_shift);
3983 
3984       exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
3985       exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
3986 
3987       sum_4 = vpadalq_u16(sum_4, exp_value1);
3988       sum_4 = vpadalq_u16(sum_4, exp_value2);
3989     }
3990     int temp = vgetq_lane_u32(sum_4, 0) + vgetq_lane_u32(sum_4, 1) +
3991                vgetq_lane_u32(sum_4, 2) + vgetq_lane_u32(sum_4, 3);
3992     sum_exp += temp;
3993 
3994 #endif
3995     for (; sum_j < last_dim; ++sum_j) {
3996       const uint8_t index = (input_data_uint[sum_j] ^ offset) + table_offset;
3997 
3998       uint8_t part1 = params.uint8_table1[index];
3999       uint8_t part2 = params.uint8_table2[index];
4000       sum_exp += ((part1 << 8) + part2);
4001     }
4002 
4003     const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
4004 
4005     int32 multiplier, shift;
4006     QuantizeMultiplier(inv_sum_exp, &multiplier, &shift);
4007 
4008     // Normalize and quantize probabilities.
4009     int j = 0;
4010 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
4011     const int32x4_t output_zp_dup = vdupq_n_s32(params.zero_point);
4012     const int32x4_t max_val_dup = vdupq_n_s32(clamp_max);
4013     const int32x4_t min_val_dup = vdupq_n_s32(clamp_min);
4014 
4015     for (; j <= last_dim - 16; j += 16) {
4016       uint8x16_t input_value = vld1q_u8(input_data_uint + j);
4017       input_value = veorq_u8(input_value, offset_dup);
4018       input_value = vaddq_u8(input_value, table_offset_dup);
4019 
4020       const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
4021       const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
4022 
4023       uint16x8_t exp_value1 =
4024           vshll_n_u8(vget_high_u8(output1), multiplier_shift);
4025       uint16x8_t exp_value2 =
4026           vshll_n_u8(vget_low_u8(output1), multiplier_shift);
4027 
4028       exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
4029       exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
4030 
4031       int32x4x4_t output_value;
4032       output_value.val[0] =
4033           vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value1)));
4034       output_value.val[1] =
4035           vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value1)));
4036       output_value.val[2] =
4037           vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value2)));
4038       output_value.val[3] =
4039           vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value2)));
4040 
4041       int32x4x4_t temp_val =
4042           MultiplyByQuantizedMultiplier4Rows(output_value, multiplier, shift);
4043 
4044       temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
4045       temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
4046       temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
4047       temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);
4048 
4049       temp_val.val[0] =
4050           vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
4051       temp_val.val[1] =
4052           vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
4053       temp_val.val[2] =
4054           vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
4055       temp_val.val[3] =
4056           vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
4057 
4058       StoreValue(temp_val, output_data + j);
4059     }
4060 #endif
4061     for (; j < last_dim; ++j) {
4062       const uint8_t index = (input_data_uint[j] ^ offset) + table_offset;
4063       const uint8_t part1 = params.uint8_table1[index];
4064       const uint8_t part2 = params.uint8_table2[index];
4065       const int32_t exp_value = (part1 << 8) + part2;
4066       const int32_t output_value =
4067           MultiplyByQuantizedMultiplier(exp_value, multiplier, shift);
4068 
4069       output_data[j] = static_cast<Out>(std::max(
4070           std::min(clamp_max, output_value + params.zero_point), clamp_min));
4071     }
4072     input_data_uint += last_dim;
4073     output_data += last_dim;
4074   }
4075 }
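
// Editor-added usage sketch (not part of the original TFLite sources; the
// quantization parameters are assumptions typical of an int8 softmax output
// with scale 1/256 and zero point -128): populate the two uint8 tables, then
// run the int8 LUT softmax above.
inline void ExampleSoftmaxInt8LutSketch() {
  SoftmaxParams params;
  params.zero_point = -128;
  params.scale = 1.0f / 256.0f;  // Output scale.
  PopulateSoftmaxUInt8LookupTable(&params, /*input_scale=*/0.1f, /*beta=*/1.0f);
  const RuntimeShape shape({1, 1, 1, 4});
  const int8_t input_data[4] = {-20, -10, 0, 10};
  int8_t output_data[4];
  SoftmaxInt8LUT(params, shape, input_data, shape, output_data);
}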
4076 
4077 inline void LogSoftmax(const SoftmaxParams& params,
4078                        const RuntimeShape& input_shape, const float* input_data,
4079                        const RuntimeShape& output_shape, float* output_data) {
4080   ruy::profiler::ScopeLabel label("LogSoftmax");
4081   const int trailing_dim = input_shape.DimensionsCount() - 1;
4082   const int outer_size =
4083       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4084   const int depth =
4085       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4086 
4087   for (int i = 0; i < outer_size; ++i) {
4088     VectorMap<const float> block_input(input_data + i * depth, depth, 1);
4089     VectorMap<float> block_output(output_data + i * depth, depth, 1);
4090     // Find max element value which we'll use to ensure numerical stability
4091     // taking advantage of the following equality:
4092     // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
4093     const float max = block_input.maxCoeff();
4094     const float log_sum = std::log((block_input.array() - max).exp().sum());
4095     block_output = block_input.array() - max - log_sum;
4096   }
4097 }
4098 
4099 // Backwards compatibility. Less optimized than below version.
4100 inline void LogSoftmax(const SoftmaxParams& params,
4101                        const RuntimeShape& input_shape, const uint8* input_data,
4102                        const RuntimeShape& output_shape, uint8* output_data) {
4103   reference_ops::LogSoftmax(params, input_shape, input_data, output_shape,
4104                             output_data);
4105 }
4106 
4107 // Compute LogSoftmax as (x - x_max) - ln(sum(e^(x_i - x_max))),
4108 // as done in tf.nn.log_softmax to prevent underflow and overflow.
4109 // This is in contrast to just log(softmax(x))
4110 //
4111 // To handle quantization, first dequantize the inputs (by computing
4112 // e^(input_scale * val); the zero point is ignored since it cancels
4113 // out during subtraction due to the ln) and do a rescale at the end to int8.
4114 //
4115 // Notably this makes use of float and is intended as the optimized
4116 // form for quantized execution on CPU. For a fully integer version,
4117 // see the reference op.
4118 //
4119 // TODO(tflite): notes for optimization:
4120 // 1) See if e^ is also bottleneck in the reference fully-integer
4121 // version and apply lookup there and compare.
4122 template <typename T>
4123 inline void LogSoftmax(const SoftmaxParams& params, float input_scale,
4124                        const RuntimeShape& input_shape, const T* input_data,
4125                        const RuntimeShape& output_shape, T* output_data) {
4126   ruy::profiler::ScopeLabel label("LogSoftmax");
4127   const int trailing_dim = input_shape.DimensionsCount() - 1;
4128   const int excluding_last_dim =
4129       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4130   const int last_dim =
4131       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4132 
4133   const int32_t clamp_max = std::numeric_limits<T>::max();
4134   const int32_t clamp_min = std::numeric_limits<T>::min();
4135 
4136   for (int i = 0; i < excluding_last_dim; ++i) {
4137     T max_val = std::numeric_limits<T>::min();
4138     // Find max quantized value.
4139     for (int j = 0; j < last_dim; ++j) {
4140       max_val = std::max(max_val, input_data[j]);
4141     }
4142 
4143     float sum_exp = 0.0f;
4144     const int32_t max_uint8 = std::numeric_limits<uint8>::max();
4145     // Offset into table to compute exp(scale*(x - xmax)) instead of
4146     // exp(scale*(x)) to prevent overflow.
4147     const float* table_offset = &params.table[max_uint8 - max_val];
4148     // Calculate sum(exp(scale*(x - x_max))).
4149     for (int j = 0; j < last_dim; ++j) {
4150       sum_exp += table_offset[input_data[j]];
4151     }
4152     const float log_sum_exp = std::log(sum_exp);
4153 
4154     // params.scale is the output scale.
4155     const float scale = input_scale / params.scale;
4156     const float precomputed =
4157         (input_scale * max_val + log_sum_exp) / params.scale;
4158     for (int j = 0; j < last_dim; ++j) {
4159       // Equivalent to (input_scale * (input_data[j] - max_val) - log_sum_exp) /
4160       // output_scale.
4161       const float log_prob = scale * input_data[j] - precomputed;
4162 
4163       // TODO(tflite): look into better solution.
4164       // Use std::rint over std::round (which is used in
4165       // FakeQuant) since it's multiple times faster on tested arm32.
4166       const int32_t prob_quantized = std::rint(log_prob) + params.zero_point;
4167       output_data[j] = static_cast<T>(
4168           std::max(std::min(clamp_max, prob_quantized), clamp_min));
4169     }
4170     input_data += last_dim;
4171     output_data += last_dim;
4172   }
4173 }
4174 
4175 inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
4176                      const RuntimeShape& output_shape, float* output_data) {
4177   ruy::profiler::ScopeLabel label("Logistic");
4178   auto input_map = MapAsVector(input_data, input_shape);
4179   auto output_map = MapAsVector(output_data, output_shape);
4180   output_map.array() =
4181       input_map.array().unaryExpr(Eigen::internal::scalar_logistic_op<float>());
4182 }
4183 
4184 // Convenience version that allows, for example, generated-code calls to be
4185 // uniform between data types.
4186 inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
4187                      const float* input_data, const RuntimeShape& output_shape,
4188                      float* output_data) {
4189   // Drop params: not needed.
4190   Logistic(input_shape, input_data, output_shape, output_data);
4191 }
4192 
4193 inline void Logistic(const LogisticParams& params,
4194                      const RuntimeShape& input_shape, const int16* input_data,
4195                      const RuntimeShape& output_shape, int16* output_data) {
4196   ruy::profiler::ScopeLabel label("Logistic/Int16");
4197   const int flat_size = MatchingFlatSize(input_shape, output_shape);
4198 
4201 
4202   int c = 0;
4203   const int16* input_data_ptr = input_data;
4204   int16* output_data_ptr = output_data;
4205 #ifdef GEMMLOWP_NEON
4206   {
4207     // F0 uses 0 integer bits, range [-1, 1].
4208     // This is the return type of math functions such as tanh, logistic,
4209     // whose range is in [-1, 1].
4210     using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
4211     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4212     using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
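    // Editorial note: in gemmlowp, FixedPoint<RawType, kIntegerBits> interprets
    // each 16-bit lane as a Q{k}.{15-k} value, so F3 values equal raw / 2^12 and
    // F0 values equal raw / 2^15. The int16 inputs are therefore assumed to
    // already be quantized with 3 integer bits (Q3.12).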
4213 
4214     for (; c <= flat_size - 16; c += 16) {
4215       F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
4216       F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
4217       F0 output0 = gemmlowp::logistic(input0);
4218       F0 output1 = gemmlowp::logistic(input1);
4219       vst1q_s16(output_data_ptr, output0.raw());
4220       vst1q_s16(output_data_ptr + 8, output1.raw());
4221 
4222       input_data_ptr += 16;
4223       output_data_ptr += 16;
4224     }
4225     for (; c <= flat_size - 8; c += 8) {
4226       F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
4227       F0 output = gemmlowp::logistic(input);
4228       vst1q_s16(output_data_ptr, output.raw());
4229 
4230       input_data_ptr += 8;
4231       output_data_ptr += 8;
4232     }
4233   }
4234 #endif
4235 #ifdef GEMMLOWP_SSE4
4236   {
4237     // F0 uses 0 integer bits, range [-1, 1].
4238     // This is the return type of math functions such as tanh, logistic,
4239     // whose range is in [-1, 1].
4240     using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
4241     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4242     using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
4243 
4244     for (; c <= flat_size - 16; c += 16) {
4245       F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4246           _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4247       F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4248           reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
4249       F0 output0 = gemmlowp::logistic(input0);
4250       F0 output1 = gemmlowp::logistic(input1);
4251       _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4252                        output0.raw().v);
4253       _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4254                        output1.raw().v);
4255       input_data_ptr += 16;
4256       output_data_ptr += 16;
4257     }
4258     for (; c <= flat_size - 8; c += 8) {
4259       F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4260           _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4261       F0 output = gemmlowp::logistic(input);
4262       _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4263                        output.raw().v);
4264       input_data_ptr += 8;
4265       output_data_ptr += 8;
4266     }
4267   }
4268 #endif
4269 
4270   {
4271     // F0 uses 0 integer bits, range [-1, 1].
4272     // This is the return type of math functions such as tanh, logistic,
4273     // whose range is in [-1, 1].
4274     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
4275     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4276     using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
4277 
4278     for (; c < flat_size; ++c) {
4279       F3 input = F3::FromRaw(*input_data_ptr);
4280       F0 output = gemmlowp::logistic(input);
4281       *output_data_ptr = output.raw();
4282 
4283       ++input_data_ptr;
4284       ++output_data_ptr;
4285     }
4286   }
4287 }
4288 
4289 inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
4290                  const RuntimeShape& output_shape, float* output_data) {
4291   ruy::profiler::ScopeLabel label("Tanh");
4292   auto input_map = MapAsVector(input_data, input_shape);
4293   auto output_map = MapAsVector(output_data, output_shape);
4294   output_map.array() = input_map.array().tanh();
4295 }
4296 
4297 // Convenience version that allows, for example, generated-code calls to be
4298 // uniform between data types.
4299 inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
4300                  const float* input_data, const RuntimeShape& output_shape,
4301                  float* output_data) {
4302   // Drop params: not needed.
4303   Tanh(input_shape, input_data, output_shape, output_data);
4304 }
4305 
4306 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
4307                  const int16* input_data, const RuntimeShape& output_shape,
4308                  int16* output_data) {
4309   ruy::profiler::ScopeLabel label("Tanh/Int16");
4310   const int input_left_shift = params.input_left_shift;
4311   // Support for shifts is limited until we have a parameterized version of
4312   // SaturatingRoundingMultiplyByPOT().
4313   TFLITE_DCHECK_GE(input_left_shift, 0);
4314   TFLITE_DCHECK_LE(input_left_shift, 1);
4315 
4316   const int flat_size = MatchingFlatSize(input_shape, output_shape);
4317 
4318   int c = 0;
4319   const int16* input_data_ptr = input_data;
4320   int16* output_data_ptr = output_data;
4321 #ifdef GEMMLOWP_NEON
4322   {
4323     // F0 uses 0 integer bits, range [-1, 1].
4324     // This is the return type of math functions such as tanh, logistic,
4325     // whose range is in [-1, 1].
4326     using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
4327     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4328     using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
4329 
4330     if (input_left_shift == 0) {
4331       for (; c <= flat_size - 16; c += 16) {
4332         F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
4333         F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
4334         F0 output0 = gemmlowp::tanh(input0);
4335         F0 output1 = gemmlowp::tanh(input1);
4336         vst1q_s16(output_data_ptr, output0.raw());
4337         vst1q_s16(output_data_ptr + 8, output1.raw());
4338 
4339         input_data_ptr += 16;
4340         output_data_ptr += 16;
4341       }
4342       for (; c <= flat_size - 8; c += 8) {
4343         F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
4344         F0 output = gemmlowp::tanh(input);
4345         vst1q_s16(output_data_ptr, output.raw());
4346 
4347         input_data_ptr += 8;
4348         output_data_ptr += 8;
4349       }
4350     } else {
4351       for (; c <= flat_size - 16; c += 16) {
4352         F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4353             vld1q_s16(input_data_ptr)));
4354         F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4355             vld1q_s16(input_data_ptr + 8)));
4356         F0 output0 = gemmlowp::tanh(input0);
4357         F0 output1 = gemmlowp::tanh(input1);
4358         vst1q_s16(output_data_ptr, output0.raw());
4359         vst1q_s16(output_data_ptr + 8, output1.raw());
4360 
4361         input_data_ptr += 16;
4362         output_data_ptr += 16;
4363       }
4364       for (; c <= flat_size - 8; c += 8) {
4365         F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4366             vld1q_s16(input_data_ptr)));
4367         F0 output = gemmlowp::tanh(input);
4368         vst1q_s16(output_data_ptr, output.raw());
4369 
4370         input_data_ptr += 8;
4371         output_data_ptr += 8;
4372       }
4373     }
4374   }
4375 #endif
4376 #ifdef GEMMLOWP_SSE4
4377   {
4378     // F0 uses 0 integer bits, range [-1, 1].
4379     // This is the return type of math functions such as tanh, logistic,
4380     // whose range is in [-1, 1].
4381     using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
4382     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4383     using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
4384 
4385     if (input_left_shift == 0) {
4386       for (; c <= flat_size - 16; c += 16) {
4387         F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4388             _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4389         F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4390             reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
4391         F0 output0 = gemmlowp::tanh(input0);
4392         F0 output1 = gemmlowp::tanh(input1);
4393         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4394                          output0.raw().v);
4395         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4396                          output1.raw().v);
4397 
4398         input_data_ptr += 16;
4399         output_data_ptr += 16;
4400       }
4401       for (; c <= flat_size - 8; c += 8) {
4402         F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4403             _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4404         F0 output = gemmlowp::tanh(input);
4405         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4406                          output.raw().v);
4407         input_data_ptr += 8;
4408         output_data_ptr += 8;
4409       }
4410     } else {
4411       for (; c <= flat_size - 16; c += 16) {
4412         F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4413             gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4414                 reinterpret_cast<const __m128i*>(input_data_ptr)))));
4415         F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4416             gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4417                 reinterpret_cast<const __m128i*>(input_data_ptr + 8)))));
4418         F0 output0 = gemmlowp::tanh(input0);
4419         F0 output1 = gemmlowp::tanh(input1);
4420         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4421                          output0.raw().v);
4422         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4423                          output1.raw().v);
4424 
4425         input_data_ptr += 16;
4426         output_data_ptr += 16;
4427       }
4428       for (; c <= flat_size - 8; c += 8) {
4429         F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4430             gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4431                 reinterpret_cast<const __m128i*>(input_data_ptr)))));
4432         F0 output = gemmlowp::tanh(input);
4433         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4434                          output.raw().v);
4435         input_data_ptr += 8;
4436         output_data_ptr += 8;
4437       }
4438     }
4439   }
4440 #endif
4441 
4442   {
4443     // F0 uses 0 integer bits, range [-1, 1].
4444     // This is the return type of math functions such as tanh, logistic,
4445     // whose range is in [-1, 1].
4446     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
4447     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4448     using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
4449 
4450     if (input_left_shift == 0) {
4451       for (; c < flat_size; ++c) {
4452         F3 input = F3::FromRaw(*input_data_ptr);
4453         F0 output = gemmlowp::tanh(input);
4454         *output_data_ptr = output.raw();
4455 
4456         ++input_data_ptr;
4457         ++output_data_ptr;
4458       }
4459     } else {
4460       for (; c < flat_size; ++c) {
4461         F3 input = F3::FromRaw(
4462             gemmlowp::SaturatingRoundingMultiplyByPOT<1>(*input_data_ptr));
4463         F0 output = gemmlowp::tanh(input);
4464         *output_data_ptr = output.raw();
4465 
4466         ++input_data_ptr;
4467         ++output_data_ptr;
4468       }
4469     }
4470   }
4471 }
4472 
4473 template <typename SrcT, typename DstT>
4474 inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
4475                  const RuntimeShape& output_shape, DstT* output_data) {
4476   ruy::profiler::ScopeLabel label("Cast");
4477   auto input_map = MapAsVector(input_data, input_shape);
4478   auto output_map = MapAsVector(output_data, output_shape);
4479   output_map.array() = input_map.array().template cast<DstT>();
4480 }
4481 
4482 inline void Floor(const RuntimeShape& input_shape, const float* input_data,
4483                   const RuntimeShape& output_shape, float* output_data) {
4484   ruy::profiler::ScopeLabel label("Floor");
4485   auto input_map = MapAsVector(input_data, input_shape);
4486   auto output_map = MapAsVector(output_data, output_shape);
4487   output_map.array() = Eigen::floor(input_map.array());
4488 }
4489 
4490 inline void Ceil(const RuntimeShape& input_shape, const float* input_data,
4491                  const RuntimeShape& output_shape, float* output_data) {
4492   ruy::profiler::ScopeLabel label("Ceil");
4493   auto input_map = MapAsVector(input_data, input_shape);
4494   auto output_map = MapAsVector(output_data, output_shape);
4495   output_map.array() = Eigen::ceil(input_map.array());
4496 }
4497 
4498 // Helper methods for BatchToSpaceND.
4499 // `spatial_index_dim` specifies the post-crop offset index in this spatial
4500 // dimension, i.e. the spatial offset introduced by flattening the batch into
4501 // the spatial dimension, minus the crop size at the beginning.
4502 // `block_shape_dim` is the block size in the current dimension. `input_dim` and
4503 // `output_dim` are the input and output sizes of BatchToSpaceND in this dimension.
4504 // Output start index is inclusive and end index is exclusive.
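// Editorial example (not from the original source): with spatial_index_dim = -2,
// block_shape_dim = 3, input_dim = 4 and output_dim = 10,
//   *start_index = max(0, ceil(2 / 3))  = 1
//   *end_index   = min(4, ceil(12 / 3)) = 4,
// so input indices 1..3 are kept; they map to output positions 1*3 - 2 = 1,
// 2*3 - 2 = 4 and 3*3 - 2 = 7, all inside [0, 10), while input index 0 would
// map to -2 and is correctly excluded.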
4505 inline void GetIndexRange(int spatial_index_dim, int block_shape_dim,
4506                           int input_dim, int output_dim, int* start_index,
4507                           int* end_index) {
4508   // (*start_index) * block_shape_dim is effectively rounded up to the next
4509   // multiple of block_shape_dim by the integer division.
4510   *start_index =
4511       std::max(0, (-spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
4512   // Similarly, (*end_index) * block_shape_dim is rounded up too (note that
4513   // end_index is exclusive).
4514   *end_index = std::min(
4515       input_dim,
4516       (output_dim - spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
4517 }
4518 
4519 template <typename T>
4520 inline void BatchToSpaceND(
4521     const RuntimeShape& unextended_input1_shape, const T* input1_data,
4522     const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
4523     const RuntimeShape& unextended_input3_shape, const int32* crops_data,
4524     const RuntimeShape& unextended_output_shape, T* output_data) {
4525   ruy::profiler::ScopeLabel label("BatchToSpaceND");
4526 
4527   TFLITE_DCHECK_GE(unextended_input1_shape.DimensionsCount(), 3);
4528   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
4529   TFLITE_DCHECK_EQ(unextended_input1_shape.DimensionsCount(),
4530                    unextended_output_shape.DimensionsCount());
4531 
4532   // Extends the input/output shape from 3D to 4D if needed, NHC -> NH1C.
4533   auto extend_shape = [](const RuntimeShape& shape) {
4534     if (shape.DimensionsCount() == 4) {
4535       return shape;
4536     }
4537     RuntimeShape new_shape(4, 1);
4538     new_shape.SetDim(0, shape.Dims(0));
4539     new_shape.SetDim(1, shape.Dims(1));
4540     new_shape.SetDim(3, shape.Dims(2));
4541     return new_shape;
4542   };
4543   const RuntimeShape input1_shape = extend_shape(unextended_input1_shape);
4544   const RuntimeShape output_shape = extend_shape(unextended_output_shape);
4545 
4546   const int output_width = output_shape.Dims(2);
4547   const int output_height = output_shape.Dims(1);
4548   const int output_batch_size = output_shape.Dims(0);
4549 
4550   const int depth = input1_shape.Dims(3);
4551   const int input_width = input1_shape.Dims(2);
4552   const int input_height = input1_shape.Dims(1);
4553   const int input_batch_size = input1_shape.Dims(0);
4554 
4555   const int block_shape_height = block_shape_data[0];
4556   const int block_shape_width =
4557       unextended_input1_shape.DimensionsCount() == 4 ? block_shape_data[1] : 1;
4558   const int crops_top = crops_data[0];
4559   const int crops_left =
4560       unextended_input1_shape.DimensionsCount() == 4 ? crops_data[2] : 0;
4561 
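  // Editorial note: each input batch index decomposes as
  //   in_batch = spatial_offset * output_batch_size + out_batch,
  // and spatial_offset itself splits into a row offset (spatial_offset /
  // block_shape_width) and a column offset (spatial_offset % block_shape_width)
  // within the block; this is how the flattened batch dimension is scattered
  // back into the spatial dimensions below.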
4562   for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
4563     const int out_batch = in_batch % output_batch_size;
4564     const int spatial_offset = in_batch / output_batch_size;
4565 
4566     int in_h_start = 0;
4567     int in_h_end = 0;
4568     // GetIndexRange yields the input rows whose corresponding output rows fall in [0, output_height).
4569     GetIndexRange(spatial_offset / block_shape_width - crops_top,
4570                   block_shape_height, input_height, output_height, &in_h_start,
4571                   &in_h_end);
4572 
4573     for (int in_h = in_h_start; in_h < in_h_end; ++in_h) {
4574       const int out_h = in_h * block_shape_height +
4575                         spatial_offset / block_shape_width - crops_top;
4576       TFLITE_DCHECK_GE(out_h, 0);
4577       TFLITE_DCHECK_LT(out_h, output_height);
4578 
4579       int in_w_start = 0;
4580       int in_w_end = 0;
4581       // GetIndexRange yields the input columns whose corresponding output columns fall in [0, output_width).
4582       GetIndexRange(spatial_offset % block_shape_width - crops_left,
4583                     block_shape_width, input_width, output_width, &in_w_start,
4584                     &in_w_end);
4585 
4586       for (int in_w = in_w_start; in_w < in_w_end; ++in_w) {
4587         const int out_w = in_w * block_shape_width +
4588                           spatial_offset % block_shape_width - crops_left;
4589         TFLITE_DCHECK_GE(out_w, 0);
4590         TFLITE_DCHECK_LT(out_w, output_width);
4591         T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
4592         const T* in =
4593             input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
4594         memcpy(out, in, depth * sizeof(T));
4595       }
4596     }
4597   }
4598 }
4599 
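// Fills `num` elements of type T starting at `ptr` with `value`. Editorial
// usage sketch (hypothetical buffers):
//   float buf[8];
//   TypedMemset<float>(buf, 1.5f, 8);    // element-wise fill, not a byte fill
//   TypedMemset<int32_t>(buf2, 0, n);    // takes the fast memset() path below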
4600 template <typename T>
4601 void TypedMemset(void* ptr, T value, size_t num) {
4602   // Optimization for common cases where memset() will suffice.
4603   if (value == 0 || std::is_same<T, uint8_t>::value) {
4604     memset(ptr, value, num * sizeof(T));
4605   } else {
4606     // Fallback for cases where memset() cannot reproduce the value's byte
4607     // pattern, i.e., when sizeof(T) > sizeof(uint8_t) and value != 0.
4608     char* pos = static_cast<char*>(ptr);
4609     for (size_t i = 0; i < num; ++i) {
4610       memcpy(pos, &value, sizeof(T));
4611       pos = pos + sizeof(T);
4612     }
4613   }
4614 }
4615 
4616 // This makes heavy use of Offset, along with conditional branches. There may be
4617 // opportunities for improvement.
4618 //
4619 // There are two versions of pad: Pad and PadV2.  In PadV2 there is a second
4620 // scalar input that provides the padding value.  Therefore pad_value_ptr can be
4621 // equivalent to a simple input1_data.  For Pad, it should point to a zero
4622 // value.
4623 //
4624 // Note that two typenames are required, so that T=P=int32 is considered a
4625 // specialization distinct from P=int32.
4626 template <typename T, typename P>
4627 inline void PadImpl(const tflite::PadParams& op_params,
4628                     const RuntimeShape& input_shape, const T* input_data,
4629                     const P* pad_value_ptr, const RuntimeShape& output_shape,
4630                     T* output_data) {
4631   ruy::profiler::ScopeLabel label("PadImpl");
4632   const int max_supported_dims = 5;
4633   const RuntimeShape ext_input_shape =
4634       RuntimeShape::ExtendedShape(max_supported_dims, input_shape);
4635   const RuntimeShape ext_output_shape =
4636       RuntimeShape::ExtendedShape(max_supported_dims, output_shape);
4637   TFLITE_DCHECK_LE(op_params.left_padding_count, max_supported_dims);
4638   TFLITE_DCHECK_LE(op_params.right_padding_count, max_supported_dims);
4639 
4640   // This pad kernel supports up to 5 dimensions. Copy the inputs so we can pad
4641   // them to 5 dims (yes, we are "padding the padding").
4642   std::vector<int> left_padding_copy(max_supported_dims, 0);
4643   const int left_padding_extend =
4644       max_supported_dims - op_params.left_padding_count;
4645   for (int i = 0; i < op_params.left_padding_count; ++i) {
4646     left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
4647   }
4648   std::vector<int> right_padding_copy(max_supported_dims, 0);
4649   const int right_padding_extend =
4650       max_supported_dims - op_params.right_padding_count;
4651   for (int i = 0; i < op_params.right_padding_count; ++i) {
4652     right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
4653   }
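  // Editorial example: for a 4-D NHWC pad with left_padding == {0, 2, 3, 0},
  // left_padding_copy becomes {0, 0, 2, 3, 0}; the extra leading zero is the
  // "padding of the padding" mentioned above.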
4654 
4655   const int output_batch = ext_output_shape.Dims(0);
4656   const int output_spatial_dim1 = ext_output_shape.Dims(1);
4657   const int output_spatial_dim2 = ext_output_shape.Dims(2);
4658   const int output_spatial_dim3 = ext_output_shape.Dims(3);
4659   const int output_channel = ext_output_shape.Dims(4);
4660 
4661   const int left_b_padding = left_padding_copy[0];
4662   const int left_s1_padding = left_padding_copy[1];
4663   const int left_s2_padding = left_padding_copy[2];
4664   const int left_s3_padding = left_padding_copy[3];
4665   const int left_c_padding = left_padding_copy[4];
4666 
4667   const int right_b_padding = right_padding_copy[0];
4668   const int right_s1_padding = right_padding_copy[1];
4669   const int right_s2_padding = right_padding_copy[2];
4670   const int right_s3_padding = right_padding_copy[3];
4671   const int right_c_padding = right_padding_copy[4];
4672 
4673   const int input_depth = ext_input_shape.Dims(4);
4674   const T pad_value = *pad_value_ptr;
4675 
4676   if (left_b_padding != 0) {
4677     TypedMemset<T>(output_data, pad_value,
4678                    left_b_padding * output_spatial_dim1 * output_spatial_dim2 *
4679                        output_spatial_dim3 * output_channel);
4680   }
4681   for (int out_b = left_b_padding; out_b < output_batch - right_b_padding;
4682        ++out_b) {
4683     if (left_s1_padding != 0) {
4684       TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, 0, 0, 0, 0),
4685                      pad_value,
4686                      left_s1_padding * output_spatial_dim2 *
4687                          output_spatial_dim3 * output_channel);
4688     }
4689     for (int out_p = left_s1_padding;
4690          out_p < output_spatial_dim1 - right_s1_padding; ++out_p) {
4691       if (left_s2_padding != 0) {
4692         TypedMemset<T>(
4693             output_data + Offset(ext_output_shape, out_b, out_p, 0, 0, 0),
4694             pad_value, left_s2_padding * output_spatial_dim3 * output_channel);
4695       }
4696       for (int out_h = left_s2_padding;
4697            out_h < output_spatial_dim2 - right_s2_padding; ++out_h) {
4698         if (left_s3_padding != 0) {
4699           TypedMemset<T>(
4700               output_data + Offset(ext_output_shape, out_b, out_p, out_h, 0, 0),
4701               pad_value, left_s3_padding * output_channel);
4702         }
4703         for (int out_w = left_s3_padding;
4704              out_w < output_spatial_dim3 - right_s3_padding; ++out_w) {
4705           if (left_c_padding != 0) {
4706             TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, out_p,
4707                                                 out_h, out_w, 0),
4708                            pad_value, left_c_padding);
4709           }
4710 
4711           T* out = output_data + Offset(ext_output_shape, out_b, out_p, out_h,
4712                                         out_w, left_c_padding);
4713           const T* in = input_data +
4714                         Offset(ext_input_shape, out_b - left_b_padding,
4715                                out_p - left_s1_padding, out_h - left_s2_padding,
4716                                out_w - left_s3_padding, 0);
4717           memcpy(out, in, input_depth * sizeof(T));
4718 
4719           if (right_c_padding != 0) {
4720             TypedMemset<T>(
4721                 output_data + Offset(ext_output_shape, out_b, out_p, out_h,
4722                                      out_w, output_channel - right_c_padding),
4723                 pad_value, right_c_padding);
4724           }
4725         }
4726         if (right_s3_padding != 0) {
4727           TypedMemset<T>(
4728               output_data + Offset(ext_output_shape, out_b, out_p, out_h,
4729                                    output_spatial_dim3 - right_s3_padding, 0),
4730               pad_value, right_s3_padding * output_channel);
4731         }
4732       }
4733       if (right_s2_padding != 0) {
4734         TypedMemset<T>(
4735             output_data + Offset(ext_output_shape, out_b, out_p,
4736                                  output_spatial_dim2 - right_s2_padding, 0, 0),
4737             pad_value, right_s2_padding * output_spatial_dim3 * output_channel);
4738       }
4739     }
4740     if (right_s1_padding != 0) {
4741       TypedMemset<T>(
4742           output_data + Offset(ext_output_shape, out_b,
4743                                output_spatial_dim1 - right_s1_padding, 0, 0, 0),
4744           pad_value,
4745           right_s1_padding * output_spatial_dim2 * output_spatial_dim3 *
4746               output_channel);
4747     }
4748   }
4749   if (right_b_padding != 0) {
4750     TypedMemset<T>(
4751         output_data + Offset(ext_output_shape, output_batch - right_b_padding,
4752                              0, 0, 0, 0),
4753         pad_value,
4754         right_b_padding * output_spatial_dim1 * output_spatial_dim2 *
4755             output_spatial_dim3 * output_channel);
4756   }
4757 }
4758 
4759 template <typename T, typename P>
4760 inline void Pad(const tflite::PadParams& op_params,
4761                 const RuntimeShape& input_shape, const T* input_data,
4762                 const P* pad_value_ptr, const RuntimeShape& output_shape,
4763                 T* output_data) {
4764   PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
4765           output_data);
4766 }
4767 
4768 // The second (pad-value) input can be int32 when, say, the first is uint8.
4769 template <typename T>
4770 inline void Pad(const tflite::PadParams& op_params,
4771                 const RuntimeShape& input_shape, const T* input_data,
4772                 const int32* pad_value_ptr, const RuntimeShape& output_shape,
4773                 T* output_data) {
4774   const T converted_pad_value = static_cast<T>(*pad_value_ptr);
4775   PadImpl(op_params, input_shape, input_data, &converted_pad_value,
4776           output_shape, output_data);
4777 }
4778 
4779 // This version avoids conflicting template matching.
4780 template <>
4781 inline void Pad(const tflite::PadParams& op_params,
4782                 const RuntimeShape& input_shape, const int32* input_data,
4783                 const int32* pad_value_ptr, const RuntimeShape& output_shape,
4784                 int32* output_data) {
4785   PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
4786           output_data);
4787 }
4788 
4789 // TODO(b/117643175): Optimize. (This is an introductory copy of standard Pad.)
4790 //
4791 // This pad requires that (a) left and right paddings are in the 4D patterns
4792 // {0, h_pad, w_pad, 0}, and (b) memset can be used: *pad_value_ptr == 0 and/or
4793 // T is uint8.
4794 //
4795 // There are two versions of pad: Pad and PadV2.  In PadV2 there is a second
4796 // scalar input that provides the padding value.  Therefore pad_value_ptr can be
4797 // equivalent to a simple input1_data.  For Pad, it should point to a zero
4798 // value.
4799 //
4800 // Note that two typenames are required, so that T=P=int32 is considered a
4801 // specialization distinct from P=int32.
4802 template <typename T, typename P>
4803 inline void PadImageStyleMemset(const tflite::PadParams& op_params,
4804                                 const RuntimeShape& input_shape,
4805                                 const T* input_data, const P* pad_value_ptr,
4806                                 const RuntimeShape& output_shape,
4807                                 T* output_data) {
4808   ruy::profiler::ScopeLabel label("PadImageStyle");
4809   const RuntimeShape ext_input_shape =
4810       RuntimeShape::ExtendedShape(4, input_shape);
4811   const RuntimeShape ext_output_shape =
4812       RuntimeShape::ExtendedShape(4, output_shape);
4813   TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
4814   TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
4815 
4816   // Pad kernels are limited to max 4 dimensions. Copy inputs so we can pad them
4817   // to 4 dims (yes, we are "padding the padding").
4818   std::vector<int> left_padding_copy(4, 0);
4819   const int left_padding_extend = 4 - op_params.left_padding_count;
4820   for (int i = 0; i < op_params.left_padding_count; ++i) {
4821     left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
4822   }
4823   std::vector<int> right_padding_copy(4, 0);
4824   const int right_padding_extend = 4 - op_params.right_padding_count;
4825   for (int i = 0; i < op_params.right_padding_count; ++i) {
4826     right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
4827   }
4828   // The following padding restrictions are contractual requirements, and
4829   // embody what it means for a padding op to be "image-style".
4830   TFLITE_DCHECK_EQ(left_padding_copy[0], 0);
4831   TFLITE_DCHECK_EQ(left_padding_copy[3], 0);
4832   TFLITE_DCHECK_EQ(right_padding_copy[0], 0);
4833   TFLITE_DCHECK_EQ(right_padding_copy[3], 0);
4834 
4835   const int batch = MatchingDim(ext_input_shape, 0, ext_output_shape, 0);
4836   const int output_height = ext_output_shape.Dims(1);
4837   const int output_width = ext_output_shape.Dims(2);
4838   const int input_height = ext_input_shape.Dims(1);
4839   const int input_width = ext_input_shape.Dims(2);
4840   const int depth = MatchingDim(ext_input_shape, 3, ext_output_shape, 3);
4841 
4842   const int left_h_padding = left_padding_copy[1];
4843   const int left_w_padding = left_padding_copy[2];
4844   const int right_h_padding = right_padding_copy[1];
4845   const int right_w_padding = right_padding_copy[2];
4846 
4847   TFLITE_DCHECK_EQ(output_height,
4848                    input_height + left_h_padding + right_h_padding);
4849   TFLITE_DCHECK_EQ(output_width,
4850                    input_width + left_w_padding + right_w_padding);
4851 
4852   const T pad_value = *pad_value_ptr;
4853   const int top_block_size = left_h_padding * output_width * depth;
4854   const size_t num_top_block_bytes = top_block_size * sizeof(T);
4855   const int bottom_block_size = right_h_padding * output_width * depth;
4856   const size_t num_bottom_block_bytes = bottom_block_size * sizeof(T);
4857   const int left_blocks_size = left_w_padding * depth;
4858   const size_t num_left_block_bytes = left_blocks_size * sizeof(T);
4859   const int right_blocks_size = right_w_padding * depth;
4860   const size_t num_right_block_bytes = right_blocks_size * sizeof(T);
4861   const int inner_line_size = input_width * depth;
4862   const size_t num_inner_line_bytes = inner_line_size * sizeof(T);
4863 
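  // Editorial note: because batch and channel padding are required to be zero,
  // each output image is laid out contiguously as
  //   [top block][left pad | row 0 | right pad] ... [left pad | row H-1 | right pad][bottom block],
  // which is why whole regions can be filled with single memset() calls below.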
4864   if (input_height == 0) {
4865     memset(output_data, pad_value,
4866            num_top_block_bytes + num_bottom_block_bytes);
4867   } else {
4868     for (int i = 0; i < batch; ++i) {
4869       // For each image in the batch, apply the top padding, then iterate
4870       // through rows, then apply the bottom padding.
4871       //
4872       // By unwinding one iteration, we can combine the first left-margin
4873       // padding with the top padding, and the last right-margin padding with
4874       // the bottom padding.
4875       memset(output_data, pad_value,
4876              num_top_block_bytes + num_left_block_bytes);
4877       output_data += top_block_size + left_blocks_size;
4878       memcpy(output_data, input_data, num_inner_line_bytes);
4879       input_data += inner_line_size;
4880       output_data += inner_line_size;
4881       // One iteration unwound.
4882       // Unwinding this loop affords the opportunity to reorder the loop work
4883       // and hence combine memset() calls.
4884       //
4885       // Before unwinding:
4886       // for (int j = 0; j < input_height; ++j) {
4887       //   // Pad on left, copy central data, pad on right.
4888       //   memset(output_data, pad_value, num_left_block_bytes);
4889       //   output_data += left_blocks_size;
4890       //   memcpy(output_data, input_data, num_inner_line_bytes);
4891       //   input_data += inner_line_size;
4892       //   output_data += inner_line_size;
4893       //   memset(output_data, pad_value, num_right_block_bytes);
4894       //   output_data += right_blocks_size;
4895       // }
4896       for (int j = 1; j < input_height; ++j) {
4897         memset(output_data, pad_value,
4898                num_right_block_bytes + num_left_block_bytes);
4899         output_data += right_blocks_size + left_blocks_size;
4900         memcpy(output_data, input_data, num_inner_line_bytes);
4901         input_data += inner_line_size;
4902         output_data += inner_line_size;
4903       }
4904       memset(output_data, pad_value,
4905              num_right_block_bytes + num_bottom_block_bytes);
4906       output_data += right_blocks_size + bottom_block_size;
4907     }
4908   }
4909 }
4910 
4911 template <typename T, typename P>
4912 inline void PadImageStyle(const tflite::PadParams& op_params,
4913                           const RuntimeShape& input_shape, const T* input_data,
4914                           const P* pad_value_ptr,
4915                           const RuntimeShape& output_shape, T* output_data) {
4916   reference_ops::PadImageStyle(op_params, input_shape, input_data,
4917                                pad_value_ptr, output_shape, output_data);
4918 }
4919 
4920 template <typename P>
4921 inline void PadImageStyle(const tflite::PadParams& op_params,
4922                           const RuntimeShape& input_shape,
4923                           const uint8* input_data, const P* pad_value_ptr,
4924                           const RuntimeShape& output_shape,
4925                           uint8* output_data) {
4926   PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
4927                       output_shape, output_data);
4928 }
4929 
4930 template <typename P>
4931 inline void PadImageStyle(const tflite::PadParams& op_params,
4932                           const RuntimeShape& input_shape,
4933                           const float* input_data, const P* pad_value_ptr,
4934                           const RuntimeShape& output_shape,
4935                           float* output_data) {
4936   const float converted_pad_value = static_cast<float>(*pad_value_ptr);
4937   if (converted_pad_value == 0.0f) {
4938     PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
4939                         output_shape, output_data);
4940   } else {
4941     PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
4942             output_data);
4943   }
4944 }
4945 
4946 template <typename T>
4947 inline void Slice(const tflite::SliceParams& op_params,
4948                   const RuntimeShape& input_shape,
4949                   const RuntimeShape& output_shape,
4950                   SequentialTensorWriter<T>* writer) {
4951   ruy::profiler::ScopeLabel label("Slice");
4952   const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
4953   TFLITE_DCHECK_LE(op_params.begin_count, 5);
4954   TFLITE_DCHECK_LE(op_params.size_count, 5);
4955   const int begin_count = op_params.begin_count;
4956   const int size_count = op_params.size_count;
4957   // We front-pad the begin and size vectors.
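  // Editorial example: slicing a [4, 5, 6] input with begin == {1, 0, 2} and
  // size == {2, -1, 3} (a -1 size means "to the end") gives, after extension
  // to 5-D, start == {0, 0, 1, 0, 2} and stop == {1, 1, 3, 5, 5}.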
4958   std::array<int, 5> start;
4959   std::array<int, 5> stop;
4960   for (int i = 0; i < 5; ++i) {
4961     int padded_i = 5 - i;
4962     start[i] =
4963         begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
4964     stop[i] =
4965         (size_count < padded_i || op_params.size[size_count - padded_i] == -1)
4966             ? ext_shape.Dims(i)
4967             : start[i] + op_params.size[size_count - padded_i];
4968   }
4969 
4970   for (int i0 = start[0]; i0 < stop[0]; ++i0) {
4971     for (int i1 = start[1]; i1 < stop[1]; ++i1) {
4972       for (int i2 = start[2]; i2 < stop[2]; ++i2) {
4973         for (int i3 = start[3]; i3 < stop[3]; ++i3) {
4974           const int len = stop[4] - start[4];
4975           if (len > 0)
4976             writer->WriteN(Offset(ext_shape, i0, i1, i2, i3, start[4]), len);
4977         }
4978       }
4979     }
4980   }
4981 }
4982 
4983 template <typename T>
4984 inline void Slice(const tflite::SliceParams& op_params,
4985                   const RuntimeShape& input_shape, const T* input_data,
4986                   const RuntimeShape& output_shape, T* output_data) {
4987   SequentialTensorWriter<T> writer(input_data, output_data);
4988   return Slice(op_params, input_shape, output_shape, &writer);
4989 }
4990 
4991 template <typename T>
4992 inline void Slice(const tflite::SliceParams& op_params,
4993                   const RuntimeShape& input_shape, const TfLiteTensor* input,
4994                   const RuntimeShape& output_shape, TfLiteTensor* output) {
4995   SequentialTensorWriter<T> writer(input, output);
4996   return Slice(op_params, input_shape, output_shape, &writer);
4997 }
4998 
4999 // Note: This implementation is only optimized for the case where the inner
5000 // stride == 1.
5001 template <typename T>
5002 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
5003                          const RuntimeShape& unextended_input_shape,
5004                          const RuntimeShape& unextended_output_shape,
5005                          SequentialTensorWriter<T>* writer) {
5006   using strided_slice::LoopCondition;
5007   using strided_slice::StartForAxis;
5008   using strided_slice::StopForAxis;
5009 
5010   ruy::profiler::ScopeLabel label("StridedSlice");
5011 
5012   // Note that the output_shape is not used herein.
5013   tflite::StridedSliceParams params_copy = op_params;
5014 
5015   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5);
5016   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5);
5017   const RuntimeShape input_shape =
5018       RuntimeShape::ExtendedShape(5, unextended_input_shape);
5019   const RuntimeShape output_shape =
5020       RuntimeShape::ExtendedShape(5, unextended_output_shape);
5021 
5022   // Reverse and pad to 5 dimensions because that is what the runtime code
5023   // requires (i.e. all shapes must be 5D and are given backwards).
5024   strided_slice::StridedSlicePadIndices(&params_copy, 5);
5025 
5026   const int start_0 = StartForAxis(params_copy, input_shape, 0);
5027   const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
5028   const int start_1 = StartForAxis(params_copy, input_shape, 1);
5029   const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
5030   const int start_2 = StartForAxis(params_copy, input_shape, 2);
5031   const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
5032   const int start_3 = StartForAxis(params_copy, input_shape, 3);
5033   const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
5034   const int start_4 = StartForAxis(params_copy, input_shape, 4);
5035   const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
5036   const bool inner_stride_is_1 = params_copy.strides[4] == 1;
5037 
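  // Editorial note: the offsets below are flattened row-major indices built up
  // incrementally; at the innermost level,
  //   offset_3 + i4 == (((i0 * D1 + i1) * D2 + i2) * D3 + i3) * D4 + i4,
  // where Dk == input_shape.Dims(k), i.e. the flat index of element
  // (i0, i1, i2, i3, i4) in the 5-D input.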
5038   for (int offset_0 = start_0 * input_shape.Dims(1),
5039            end_0 = stop_0 * input_shape.Dims(1),
5040            step_0 = params_copy.strides[0] * input_shape.Dims(1);
5041        !LoopCondition(offset_0, end_0, params_copy.strides[0]);
5042        offset_0 += step_0) {
5043     for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
5044              end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
5045              step_1 = params_copy.strides[1] * input_shape.Dims(2);
5046          !LoopCondition(offset_1, end_1, params_copy.strides[1]);
5047          offset_1 += step_1) {
5048       for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
5049                end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
5050                step_2 = params_copy.strides[2] * input_shape.Dims(3);
5051            !LoopCondition(offset_2, end_2, params_copy.strides[2]);
5052            offset_2 += step_2) {
5053         for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
5054                  end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
5055                  step_3 = params_copy.strides[3] * input_shape.Dims(4);
5056              !LoopCondition(offset_3, end_3, params_copy.strides[3]);
5057              offset_3 += step_3) {
5058           // When the stride is 1, the inner loop is equivalent to the
5059           // optimized slice inner loop. Otherwise, it is identical to the
5060           // strided_slice reference implementation inner loop.
5061           if (inner_stride_is_1) {
5062             const int len = stop_4 - start_4;
5063             if (len > 0) {
5064               writer->WriteN(offset_3 + start_4, len);
5065             }
5066           } else {
5067             for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
5068                  !LoopCondition(offset_4, end_4, params_copy.strides[4]);
5069                  offset_4 += params_copy.strides[4]) {
5070               writer->Write(offset_4);
5071             }
5072           }
5073         }
5074       }
5075     }
5076   }
5077 }
5078 
5079 template <typename T>
5080 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
5081                          const RuntimeShape& unextended_input_shape,
5082                          const T* input_data,
5083                          const RuntimeShape& unextended_output_shape,
5084                          T* output_data) {
5085   SequentialTensorWriter<T> writer(input_data, output_data);
5086   StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
5087                   &writer);
5088 }
5089 
5090 template <typename T>
5091 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
5092                          const RuntimeShape& unextended_input_shape,
5093                          const TfLiteTensor* input,
5094                          const RuntimeShape& unextended_output_shape,
5095                          TfLiteTensor* output) {
5096   SequentialTensorWriter<T> writer(input, output);
5097   StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
5098                   &writer);
5099 }
5100 
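// Editorial note: this fast path reads only input2_data[0], i.e. it assumes the
// second operand is (or broadcasts from) a single scalar value; the same holds
// for Maximum below.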
5101 template <typename T>
5102 void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
5103              const T* input2_data, const RuntimeShape& output_shape,
5104              T* output_data) {
5105   ruy::profiler::ScopeLabel label("TensorFlowMinimum");
5106   auto input1_map = MapAsVector(input1_data, input1_shape);
5107   auto output_map = MapAsVector(output_data, output_shape);
5108   auto min_value = input2_data[0];
5109   output_map.array() = input1_map.array().min(min_value);
5110 }
5111 
5112 // Convenience version that allows, for example, generated-code calls to be
5113 // the same as other binary ops.
5114 template <typename T>
5115 inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
5116                     const RuntimeShape&, const T* input2_data,
5117                     const RuntimeShape& output_shape, T* output_data) {
5118   // Drop shape of second input: not needed.
5119   Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
5120 }
5121 
5122 template <typename T>
5123 void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
5124              const T* input2_data, const RuntimeShape& output_shape,
5125              T* output_data) {
5126   ruy::profiler::ScopeLabel label("TensorFlowMaximum");
5127   auto input1_map = MapAsVector(input1_data, input1_shape);
5128   auto output_map = MapAsVector(output_data, output_shape);
5129   auto max_value = input2_data[0];
5130   output_map.array() = input1_map.array().max(max_value);
5131 }
5132 
5133 // Convenience version that allows, for example, generated-code calls to be
5134 // the same as other binary ops.
5135 template <typename T>
5136 inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
5137                     const RuntimeShape&, const T* input2_data,
5138                     const RuntimeShape& output_shape, T* output_data) {
5139   // Drop shape of second input: not needed.
5140   Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
5141 }
5142 
5143 template <typename T>
5144 void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
5145                      const RuntimeShape& input_shape, const T* input_data,
5146                      const RuntimeShape& filter_shape,
5147                      const RuntimeShape& output_shape, T* im2col_data) {
5148   ruy::profiler::ScopeLabel label("TransposeIm2col");
5149   const int stride_width = params.stride_width;
5150   const int stride_height = params.stride_height;
5151   const int pad_width = params.padding_values.width;
5152   const int pad_height = params.padding_values.height;
5153   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5154   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
5155   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
5156   TFLITE_DCHECK(im2col_data);
5157 
5158   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
5159   const int input_height = input_shape.Dims(1);
5160   const int input_width = input_shape.Dims(2);
5161   const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
5162   const int filter_height = filter_shape.Dims(1);
5163   const int filter_width = filter_shape.Dims(2);
5164   const int output_height = output_shape.Dims(1);
5165   const int output_width = output_shape.Dims(2);
5166   MatchingDim(output_shape, 3, filter_shape, 0);  // output_depth
5167 
5168   // Construct the MxN sized im2col matrix.
5169   // The rows, M, are sub-ordered B x H x W
5170   const RuntimeShape row_shape({1, batches, output_height, output_width});
5171   // The columns, N, are sub-ordered Kh x Kw x Din
5172   const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
5173   // Use dimensions M and N to construct dims for indexing directly into im2col
5174   const RuntimeShape im2col_shape(
5175       {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
5176 
5177   // Build the im2col matrix by looping through all the input pixels,
5178   // computing their influence on the output, rather than looping through all
5179   // the output pixels. We therefore must initialize the im2col array to zero.
5180   // This is potentially inefficient because we subsequently overwrite bytes
5181   // set here. However, in practice memset is very fast and its cost is negligible.
5182   memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T));
5183 
5184   // Loop through the output batches
5185   for (int batch = 0; batch < batches; ++batch) {
5186     // Loop through input pixels one at a time.
5187     for (int in_y = 0; in_y < input_height; ++in_y) {
5188       for (int in_x = 0; in_x < input_width; ++in_x) {
5189         // Loop through the output pixels it will influence
5190         const int out_x_origin = (in_x * stride_width) - pad_width;
5191         const int out_y_origin = (in_y * stride_height) - pad_height;
5192         for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
5193           const int out_y = out_y_origin + filter_y;
5194           // Is output pixel within height bounds?
5195           if ((out_y >= 0) && (out_y < output_height)) {
5196             for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
5197               const int out_x = out_x_origin + filter_x;
5198               // Is output pixel within width bounds?
5199               if ((out_x >= 0) && (out_x < output_width)) {
5200                 // Copy the input elements of this pixel
5201                 T const* src =
5202                     input_data + Offset(input_shape, batch, in_y, in_x, 0);
5203                 int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
5204                 int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
5205                 T* dst = im2col_data +
5206                          Offset(im2col_shape, 0, 0, row_offset, col_offset);
5207                 memcpy(dst, src, input_depth * sizeof(T));
5208               }
5209             }
5210           }
5211         }
5212       }
5213     }
5214   }
5215 }
5216 
5217 // Returns in 'im_data' (assumes to be zero-initialized) image patch in storage
5218 // order (height, width, depth), constructed from patches in 'col_data', which
5219 // is required to be in storage order (out_height * out_width, filter_height,
5220 // filter_width, in_depth).  Implementation by Yangqing Jia (jiayq).
5221 // Copied from //tensorflow/core/kernels/conv_grad_input_ops.cc
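// Editorial note: this is the scatter/accumulate counterpart of im2col: each
// (filter_h, filter_w, depth) patch of 'col_data' is added (not copied) into
// its corresponding window of 'im_data', so overlapping windows sum up, which
// is exactly what the transposed convolution below requires.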
5222 template <typename T>
5223 void Col2im(const T* col_data, const int depth, const int height,
5224             const int width, const int filter_h, const int filter_w,
5225             const int pad_t, const int pad_l, const int pad_b, const int pad_r,
5226             const int stride_h, const int stride_w, T* im_data) {
5227   ruy::profiler::ScopeLabel label("Col2im");
5228   int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
5229   int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
5230   int h_pad = -pad_t;
5231   for (int h = 0; h < height_col; ++h) {
5232     int w_pad = -pad_l;
5233     for (int w = 0; w < width_col; ++w) {
5234       T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
5235       for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
5236         for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
5237           if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
5238             // TODO(andydavis) Vectorize this loop (if compiler does not).
5239             for (int i = 0; i < depth; ++i) {
5240               im_patch_data[i] += col_data[i];
5241             }
5242           }
5243           im_patch_data += depth;
5244           col_data += depth;
5245         }
5246         // Advance to the next image row: skip the (width - filter_w) columns not covered by the filter.
5247         im_patch_data += depth * (width - filter_w);
5248       }
5249       w_pad += stride_w;
5250     }
5251     h_pad += stride_h;
5252   }
5253 }
5254 
5255 // TODO(b/188008864) Optimize this function by combining outer loops.
5256 template <typename T>
5257 void BiasAdd(T* im_data, const T* bias_data, const int batch_size,
5258              const int height, const int width, const int depth) {
5259   if (bias_data) {
5260     for (int n = 0; n < batch_size; ++n) {
5261       for (int h = 0; h < height; ++h) {
5262         for (int w = 0; w < width; ++w) {
5263           for (int d = 0; d < depth; ++d) {
5264             im_data[d] += bias_data[d];
5265           }
5266           im_data += depth;
5267         }
5268       }
5269     }
5270   }
5271 }
5272 
5273 // TransposeConvV2 expects the weights in HWOI order.
5274 inline void TransposeConvV2(
5275     const ConvParams& params, const RuntimeShape& input_shape,
5276     const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
5277     const float* hwoi_ordered_filter_data, const RuntimeShape& bias_shape,
5278     const float* bias_data, const RuntimeShape& output_shape,
5279     float* const output_data, const RuntimeShape& col2im_shape,
5280     float* col2im_data, CpuBackendContext* cpu_backend_context) {
5281   ruy::profiler::ScopeLabel label("TransposeConvV2/float");
5282   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5283   TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
5284   TFLITE_DCHECK(col2im_data);
5285   TFLITE_DCHECK(hwoi_ordered_filter_data);
5286 
5287   const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
5288   const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
5289   const int output_height = output_shape.Dims(1);
5290   const int output_width = output_shape.Dims(2);
5291   const int output_image_size = output_height * output_width;
5292   const int input_depth =
5293       MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
5294   const int output_depth =
5295       MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
5296   const int input_offset = input_image_size * input_depth;
5297   const int output_offset = output_image_size * output_depth;
5298 
5299   const int filter_height = hwoi_ordered_filter_shape.Dims(0);
5300   const int filter_width = hwoi_ordered_filter_shape.Dims(1);
5301   const int padding_top = params.padding_values.height;
5302   const int padding_bottom =
5303       params.padding_values.height + params.padding_values.height_offset;
5304   const int padding_left = params.padding_values.width;
5305   const int padding_right =
5306       params.padding_values.width + params.padding_values.width_offset;
5307   const int stride_height = params.stride_height;
5308   const int stride_width = params.stride_width;
5309 
5310   const int hwoi_ordered_filter_total_size =
5311       filter_height * filter_width * output_depth;
5312 
5313   cpu_backend_gemm::MatrixParams<float> lhs_params;
5314   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
5315   lhs_params.rows = hwoi_ordered_filter_total_size;
5316   lhs_params.cols = input_depth;
5317   float* output_data_p = output_data;
5318   std::fill_n(output_data, output_offset * batch_size, 0.0f);
5319   for (int i = 0; i < batch_size; ++i) {
5320     cpu_backend_gemm::MatrixParams<float> rhs_params;
5321     rhs_params.order = cpu_backend_gemm::Order::kColMajor;
5322     rhs_params.rows = input_depth;
5323     rhs_params.cols = input_image_size;
5324     cpu_backend_gemm::MatrixParams<float> dst_params;
5325     dst_params.order = cpu_backend_gemm::Order::kColMajor;
5326     dst_params.rows = hwoi_ordered_filter_total_size;
5327     dst_params.cols = input_image_size;
5328     cpu_backend_gemm::GemmParams<float, float> gemm_params;
5329     cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
5330                            input_data + input_offset * i, dst_params,
5331                            col2im_data, gemm_params, cpu_backend_context);
5332 
5333     Col2im(col2im_data, output_depth, output_height, output_width,
5334            filter_height, filter_width, padding_top, padding_left,
5335            padding_bottom, padding_right, stride_height, stride_width,
5336            output_data_p);
5337     output_data_p += output_offset;
5338   }
5339   output_data_p = output_data;
5340   BiasAdd(output_data_p, bias_data, batch_size, output_height, output_width,
5341           output_depth);
5342 }
5343 
5344 inline void Quantize(int32_t multiplier, int32_t shift, int32_t total_size,
5345                      int32_t output_zp, int32_t* scratch, uint8_t* output) {
5346   ruy::profiler::ScopeLabel label("Quantize/uint8");
5347   int i = 0;
5348   const int32_t output_min = std::numeric_limits<uint8_t>::min();
5349   const int32_t output_max = std::numeric_limits<uint8_t>::max();
5350 
5351 #ifdef USE_NEON
5352   const int32x4_t output_zp_dup = vdupq_n_s32(output_zp);
5353   const int32x4_t max_val_dup = vdupq_n_s32(output_max);
5354   const int32x4_t min_val_dup = vdupq_n_s32(output_min);
5355 
5356   using gemmlowp::RoundingDivideByPOT;
5357   using gemmlowp::SaturatingRoundingDoublingHighMul;
5358 
5359   for (; i <= total_size - 16; i += 16) {
5360     int32x4x4_t scratch_val;
5361     scratch_val.val[0] = vld1q_s32(scratch + i);
5362     scratch_val.val[1] = vld1q_s32(scratch + i + 4);
5363     scratch_val.val[2] = vld1q_s32(scratch + i + 8);
5364     scratch_val.val[3] = vld1q_s32(scratch + i + 12);
5365 
5366     int32x4x4_t temp_val =
5367         MultiplyByQuantizedMultiplier4Rows(scratch_val, multiplier, shift);
5368 
5369     temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
5370     temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
5371     temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
5372     temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);
5373 
5374     temp_val.val[0] =
5375         vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
5376     temp_val.val[1] =
5377         vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
5378     temp_val.val[2] =
5379         vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
5380     temp_val.val[3] =
5381         vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
5382 
5383     const uint16x8_t result_1 =
5384         vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[0])),
5385                      vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[1])));
5386     const uint16x8_t result_2 =
5387         vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[2])),
5388                      vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[3])));
5389     const uint8x16_t result =
5390         vcombine_u8(vqmovn_u16(result_1), vqmovn_u16(result_2));
5391     vst1q_u8(output + i, result);
5392   }
5393 #endif
5394   for (; i < total_size; ++i) {
5395     int32_t temp = MultiplyByQuantizedMultiplier(scratch[i], multiplier, shift);
5396     temp += output_zp;
5397     if (temp > output_max) {
5398       temp = output_max;
5399     }
5400     if (temp < output_min) {
5401       temp = output_min;
5402     }
5403     output[i] = static_cast<uint8_t>(temp);
5404   }
5405 }
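// Scalar view of the routine above, for reference: each int32 accumulator is
// rescaled by the real-valued factor encoded as (multiplier, shift) -- roughly
// acc * multiplier * 2^shift / 2^31 with round-to-nearest, following the usual
// TFLite convention of MultiplyByQuantizedMultiplier -- then the output zero
// point is added and the result is clamped to [0, 255] before the cast to
// uint8.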
5406 
5407 inline void Quantize(const int32_t* multiplier, const int32_t* shift,
5408                      int32_t channel_size, int32_t total_size,
5409                      int32_t output_zp, int32_t output_min, int32_t output_max,
5410                      int32_t* scratch, int8_t* output) {
5411   ruy::profiler::ScopeLabel label("Quantize/int8");
5412 
5413   // Here we're trying to quantize the raw accumulators:
5414   //        output_channels
5415   //       data data data data data
5416   // rows  data data data data data
5417   //       data data data data data
5418   //          ....
5419   //
5420 // To minimize reloading of the multipliers & shifts, we load them once per
5421 // block of channels and then quantize the corresponding raw accumulators for
5422 // every row.
5423 #ifdef USE_NEON
5424   const int32x4_t output_offset_vec = vdupq_n_s32(output_zp);
5425   const int32x4_t output_activation_min_vec = vdupq_n_s32(output_min);
5426   const int32x4_t output_activation_max_vec = vdupq_n_s32(output_max);
5427   const int32x4_t zeros = vdupq_n_s32(0);
5428 #endif
5429 
5430   TFLITE_DCHECK_EQ(total_size % channel_size, 0);
5431   const int32_t rows = total_size / channel_size;
5432 
5433   int c = 0;
5434 
5435 #ifdef USE_NEON
5436   using gemmlowp::RoundingDivideByPOT;
5437   for (; c <= channel_size - 8; c += 8) {
5438     int32x4_t out_shift_1 = vld1q_s32(shift + c);
5439     int32x4_t out_shift_2 = vld1q_s32(shift + c + 4);
5440     int32x4_t left_shift_1 = vmaxq_s32(out_shift_1, zeros);
5441     int32x4_t left_shift_2 = vmaxq_s32(out_shift_2, zeros);
5442 
5443     // Right shift will be performed as left shift with negative values.
5444     int32x4_t right_shift_1 = vminq_s32(out_shift_1, zeros);
5445     int32x4_t right_shift_2 = vminq_s32(out_shift_2, zeros);
5446 
5447     int32x4_t out_mul_1 = vld1q_s32(multiplier + c);
5448     int32x4_t out_mul_2 = vld1q_s32(multiplier + c + 4);
5449     for (int n = 0; n < rows; ++n) {
5450       int loc = n * channel_size + c;
5451       int32x4_t acc_1 = vld1q_s32(scratch + loc);
5452       int32x4_t acc_2 = vld1q_s32(scratch + loc + 4);
5453 
5454       // Saturating Rounding Doubling High Mul.
5455       acc_1 = vshlq_s32(acc_1, left_shift_1);
5456       acc_1 = vqrdmulhq_s32(acc_1, out_mul_1);
5457       acc_2 = vshlq_s32(acc_2, left_shift_2);
5458       acc_2 = vqrdmulhq_s32(acc_2, out_mul_2);
5459 
5460       // Rounding Dividing By POT.
5461       acc_1 = vrshlq_s32(acc_1, right_shift_1);
5462       acc_2 = vrshlq_s32(acc_2, right_shift_2);
5463 
5464       // Add the output offset.
5465       acc_1 = vaddq_s32(acc_1, output_offset_vec);
5466       acc_2 = vaddq_s32(acc_2, output_offset_vec);
5467 
5468       // Apply the activation function.
5469       acc_1 = vmaxq_s32(acc_1, output_activation_min_vec);
5470       acc_1 = vminq_s32(acc_1, output_activation_max_vec);
5471       acc_2 = vmaxq_s32(acc_2, output_activation_min_vec);
5472       acc_2 = vminq_s32(acc_2, output_activation_max_vec);
5473 
5474       // Saturating cast to int8 and store to destination.
5475       const int16x4_t acc_s16_1 = vqmovn_s32(acc_1);
5476       const int16x4_t acc_s16_2 = vqmovn_s32(acc_2);
5477       const int16x8_t res_s16 = vcombine_s16(acc_s16_1, acc_s16_2);
5478       const int8x8_t res_s8 = vqmovn_s16(res_s16);
5479       vst1_s8(output + loc, res_s8);
5480     }
5481   }
5482 
5483 #endif  // USE_NEON
5484   // Handle leftover values, one by one. This is very slow.
5485   for (; c < channel_size; c++) {
5486     for (int n = 0; n < rows; ++n) {
5487       int loc = n * channel_size + c;
5488       int32 acc = scratch[loc];
5489       acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]);
5490       acc += output_zp;
5491       acc = std::max(acc, output_min);
5492       acc = std::min(acc, output_max);
5493       output[loc] = static_cast<int8>(acc);
5494     }
5495   }
5496 }
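// Indexing sketch for the per-channel path above (hypothetical sizes): with
// channel_size = 8 and rows = 2 (so total_size = 16), the accumulator for row n
// and channel c lives at scratch[n * channel_size + c]; the NEON loop keeps
// multiplier[c..c+7] and shift[c..c+7] in registers and walks n = 0..rows-1
// before moving on to the next block of 8 channels.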
5497 
5498 // TransposeConvV2 expects the weights in HWOI order.
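// Same structure as the float TransposeConvV2 above, except the per-batch GEMM
// accumulates into the int32 'col2im_data' buffer, Col2im and BiasAdd operate on
// the int32 'scratch_data' buffer, and a final Quantize pass rescales the
// accumulators to uint8 with the output zero point.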
5499 inline void TransposeConvV2(
5500     const ConvParams& params, const RuntimeShape& input_shape,
5501     const uint8_t* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
5502     const uint8_t* hwoi_ordered_filter_data, const RuntimeShape& bias_shape,
5503     const int32* bias_data, const RuntimeShape& output_shape,
5504     uint8_t* output_data, const RuntimeShape& col2im_shape,
5505     int32_t* col2im_data, int32_t* scratch_data,
5506     CpuBackendContext* cpu_backend_context) {
5507   ruy::profiler::ScopeLabel label("TransposeConvV2/uint8");
5508   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5509   TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
5510   TFLITE_DCHECK(col2im_data);
5511   TFLITE_DCHECK(hwoi_ordered_filter_data);
5512 
5513   const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
5514   const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
5515   const int output_height = output_shape.Dims(1);
5516   const int output_width = output_shape.Dims(2);
5517   const int output_image_size = output_height * output_width;
5518   const int input_depth =
5519       MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
5520   const int output_depth =
5521       MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
5522   const int input_offset = input_image_size * input_depth;
5523   const int output_offset = output_image_size * output_depth;
5524 
5525   const int filter_height = hwoi_ordered_filter_shape.Dims(0);
5526   const int filter_width = hwoi_ordered_filter_shape.Dims(1);
5527   const int padding_top = params.padding_values.height;
5528   const int padding_bottom =
5529       params.padding_values.height + params.padding_values.height_offset;
5530   const int padding_left = params.padding_values.width;
5531   const int padding_right =
5532       params.padding_values.width + params.padding_values.width_offset;
5533   const int stride_height = params.stride_height;
5534   const int stride_width = params.stride_width;
5535 
5536   const int hwoi_ordered_filter_total_size =
5537       filter_height * filter_width * output_depth;
5538 
5539   cpu_backend_gemm::MatrixParams<uint8_t> lhs_params;
5540   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
5541   lhs_params.rows = hwoi_ordered_filter_total_size;
5542   lhs_params.cols = input_depth;
5543   lhs_params.zero_point = -params.weights_offset;
5544 
5545   int32_t* scratch_data_p = scratch_data;
5546   std::fill_n(scratch_data, output_offset * batch_size, static_cast<int32>(0));
5547   for (int i = 0; i < batch_size; ++i) {
5548     cpu_backend_gemm::MatrixParams<uint8_t> rhs_params;
5549     rhs_params.order = cpu_backend_gemm::Order::kColMajor;
5550     rhs_params.rows = input_depth;
5551     rhs_params.cols = input_image_size;
5552     rhs_params.zero_point = -params.input_offset;
5553 
5554     cpu_backend_gemm::MatrixParams<int32_t> dst_params;
5555     dst_params.order = cpu_backend_gemm::Order::kColMajor;
5556     dst_params.rows = hwoi_ordered_filter_total_size;
5557     dst_params.cols = input_image_size;
5558 
5559     cpu_backend_gemm::GemmParams<int32_t, int32_t> gemm_params;
5560     cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
5561                            input_data + input_offset * i, dst_params,
5562                            col2im_data, gemm_params, cpu_backend_context);
5563 
5564     Col2im(col2im_data, output_depth, output_height, output_width,
5565            filter_height, filter_width, padding_top, padding_left,
5566            padding_bottom, padding_right, stride_height, stride_width,
5567            scratch_data_p);
5568 
5569     scratch_data_p += output_offset;
5570   }
5571   scratch_data_p = scratch_data;
5572   BiasAdd(scratch_data_p, bias_data, batch_size, output_height, output_width,
5573           output_depth);
5574 
5575   Quantize(params.output_multiplier, params.output_shift,
5576            output_shape.FlatSize(), params.output_offset, scratch_data,
5577            output_data);
5578 }
5579 
5580 // Integer-only version of ResizeNearestNeighbor. Since scales are represented
5581 // in fixed-point and thus approximated, |in_x| or |in_y| may differ from the
5582 // reference version. Debug checks are in place to test if this occurs.
5583 // NOTE: If align_corners or half_pixel_centers is true, we use the reference
5584 // version.
5585 inline void ResizeNearestNeighbor(
5586     const tflite::ResizeNearestNeighborParams& op_params,
5587     const RuntimeShape& unextended_input_shape, const uint8* input_data,
5588     const RuntimeShape& output_size_shape, const int32* output_size_data,
5589     const RuntimeShape& unextended_output_shape, uint8* output_data) {
5590   if (op_params.align_corners || op_params.half_pixel_centers) {
5591     // TODO(b/149823713): Add support for align_corners & half_pixel_centers in
5592     // this kernel.
5593     reference_ops::ResizeNearestNeighbor(
5594         op_params, unextended_input_shape, input_data, output_size_shape,
5595         output_size_data, unextended_output_shape, output_data);
5596     return;
5597   }
5598   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
5599   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
5600 
5601   const RuntimeShape input_shape =
5602       RuntimeShape::ExtendedShape(4, unextended_input_shape);
5603   const RuntimeShape output_shape =
5604       RuntimeShape::ExtendedShape(4, unextended_output_shape);
5605 
5606   int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
5607   int32 input_height = input_shape.Dims(1);
5608   int32 input_width = input_shape.Dims(2);
5609   int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
5610 
5611   // The TensorFlow version of this op allows resizing along the width and
5612   // height axes only.
5613   TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
5614   int32 output_height = output_size_data[0];
5615   int32 output_width = output_size_data[1];
5616 
5617   // Convert scales to fixed-point with 16 fractional bits. We add 1 as an
5618   // error factor and to avoid zero scales. For example, with input_height = 1,
5619   // output_height = 3, the float scaling factor would be non-zero at 1/3.
5620   // With fixed-point, this is zero.
5621   int32 height_scale = (input_height << 16) / output_height + 1;
5622   int32 width_scale = (input_width << 16) / output_width + 1;
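  // Worked example (illustrative): input_height = 1, output_height = 3 gives
  // height_scale = (1 << 16) / 3 + 1 = 21846, so for y = 2 the source row is
  // (2 * 21846) >> 16 = 0, matching floor(2 * 1 / 3) from the float reference.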
5623 
5624   const int col_offset = input_shape.Dims(3);
5625   const int row_offset = input_shape.Dims(2) * col_offset;
5626   const int batch_offset = input_shape.Dims(1) * row_offset;
5627 
5628   const uint8* input_ptr = input_data;
5629   uint8* output_ptr = output_data;
5630   for (int b = 0; b < batches; ++b) {
5631     for (int y = 0; y < output_height; ++y) {
5632       int32 in_y = std::min((y * height_scale) >> 16, input_height - 1);
5633       // Check offset calculation is the same as the reference version. See
5634       // function comment for details. We check using a non-float version of:
5635       // TFLITE_DCHECK_EQ(in_y, std::floor(y * (static_cast<float>(input_height)
5636       //                                            / output_height)));
5637       TFLITE_DCHECK_LT(y * input_height, output_height + in_y * output_height);
5638       TFLITE_DCHECK_GE(y * input_height, in_y * output_height);
5639       const uint8* y_input_ptr = input_ptr + in_y * row_offset;
5640       for (int x = 0; x < output_width; ++x) {
5641         int32 in_x = std::min((x * width_scale) >> 16, input_width - 1);
5642         // Check offset calculation is the same as the reference version. See
5643         // function comment for details. We check using a non-float version of:
5644         // TFLITE_DCHECK_EQ(in_x,
5645         //                  std::floor(x * (static_cast<float>(input_width)
5646         //                                      / output_width)));
5647         TFLITE_DCHECK_LT(x * input_width, output_width + in_x * output_width);
5648         TFLITE_DCHECK_GE(x * input_width, in_x * output_width);
5649         const uint8* x_input_ptr = y_input_ptr + in_x * col_offset;
5650         memcpy(output_ptr, x_input_ptr, depth);
5651         output_ptr += depth;
5652       }
5653     }
5654     input_ptr += batch_offset;
5655   }
5656 }
5657 
5658 template <typename input_type, typename output_type>
5659 inline void Requantize(const input_type* input_data, int32_t size,
5660                        int32_t effective_scale_multiplier,
5661                        int32_t effective_scale_shift, int32_t input_zeropoint,
5662                        int32_t output_zeropoint, output_type* output_data) {
5663   reference_ops::Requantize(input_data, size, effective_scale_multiplier,
5664                             effective_scale_shift, input_zeropoint,
5665                             output_zeropoint, output_data);
5666 }
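// Note on the (effective_scale_multiplier, effective_scale_shift) pair used by
// the specializations below: callers conventionally derive it by quantizing the
// ratio input_scale / output_scale (e.g. via QuantizeMultiplier), so in effect
// each element becomes
//   output = clamp(round((input - input_zeropoint) * input_scale / output_scale)
//                  + output_zeropoint)
// computed entirely in fixed-point arithmetic.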
5667 
5668 template <>
5669 inline void Requantize<int8_t, uint8_t>(const int8_t* input_data, int32_t size,
5670                                         int32_t effective_scale_multiplier,
5671                                         int32_t effective_scale_shift,
5672                                         int32_t input_zeropoint,
5673                                         int32_t output_zeropoint,
5674                                         uint8_t* output_data) {
5675   ruy::profiler::ScopeLabel label("Requantize/Int8ToUint8");
5676 
5677   static constexpr int32_t kMinOutput = std::numeric_limits<uint8_t>::min();
5678   static constexpr int32_t kMaxOutput = std::numeric_limits<uint8_t>::max();
5679 
5680   int i = 0;
5681 #ifdef USE_NEON
5682   // Constants.
5683   const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5684   const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5685   const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5686   const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5687 
5688   for (; i <= size - 16; i += 16) {
5689     const int8x16_t input_vec = vld1q_s8(input_data + i);
5690     const int16x8_t first_half = vmovl_s8(vget_low_s8(input_vec));
5691     const int16x8_t second_half = vmovl_s8(vget_high_s8(input_vec));
5692     int32x4x4_t input;
5693     input.val[0] = vmovl_s16(vget_low_s16(first_half));
5694     input.val[1] = vmovl_s16(vget_high_s16(first_half));
5695     input.val[2] = vmovl_s16(vget_low_s16(second_half));
5696     input.val[3] = vmovl_s16(vget_high_s16(second_half));
5697     input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5698     input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5699     input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5700     input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5701 
5702     int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5703         input, effective_scale_multiplier, effective_scale_shift);
5704 
5705     result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5706     result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5707     result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5708     result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5709     result.val[0] =
5710         vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5711     result.val[1] =
5712         vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5713     result.val[2] =
5714         vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5715     result.val[3] =
5716         vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5717 
5718     const uint32x4_t result_val_1_unsigned =
5719         vreinterpretq_u32_s32(result.val[0]);
5720     const uint32x4_t result_val_2_unsigned =
5721         vreinterpretq_u32_s32(result.val[1]);
5722     const uint32x4_t result_val_3_unsigned =
5723         vreinterpretq_u32_s32(result.val[2]);
5724     const uint32x4_t result_val_4_unsigned =
5725         vreinterpretq_u32_s32(result.val[3]);
5726 
5727     const uint16x4_t narrowed_val_1 = vqmovn_u32(result_val_1_unsigned);
5728     const uint16x4_t narrowed_val_2 = vqmovn_u32(result_val_2_unsigned);
5729     const uint16x4_t narrowed_val_3 = vqmovn_u32(result_val_3_unsigned);
5730     const uint16x4_t narrowed_val_4 = vqmovn_u32(result_val_4_unsigned);
5731     const uint16x8_t output_first_half =
5732         vcombine_u16(narrowed_val_1, narrowed_val_2);
5733     const uint16x8_t output_second_half =
5734         vcombine_u16(narrowed_val_3, narrowed_val_4);
5735     const uint8x8_t narrowed_first_half = vqmovn_u16(output_first_half);
5736     const uint8x8_t narrowed_second_half = vqmovn_u16(output_second_half);
5737     const uint8x16_t narrowed_result =
5738         vcombine_u8(narrowed_first_half, narrowed_second_half);
5739     vst1q_u8(output_data + i, narrowed_result);
5740   }
5741 
5742 #endif
5743   for (; i < size; ++i) {
5744     const int32_t input = input_data[i] - input_zeropoint;
5745     const int32_t output =
5746         MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5747                                       effective_scale_shift) +
5748         output_zeropoint;
5749     const int32_t clamped_output =
5750         std::max(std::min(output, kMaxOutput), kMinOutput);
5751     output_data[i] = static_cast<uint8_t>(clamped_output);
5752   }
5753 }
5754 
5755 template <>
5756 inline void Requantize<uint8_t, int8_t>(const uint8_t* input_data, int32_t size,
5757                                         int32_t effective_scale_multiplier,
5758                                         int32_t effective_scale_shift,
5759                                         int32_t input_zeropoint,
5760                                         int32_t output_zeropoint,
5761                                         int8_t* output_data) {
5762   ruy::profiler::ScopeLabel label("Requantize/Uint8ToInt8");
5763 
5764   static constexpr int32_t kMinOutput = std::numeric_limits<int8_t>::min();
5765   static constexpr int32_t kMaxOutput = std::numeric_limits<int8_t>::max();
5766 
5767   int i = 0;
5768 #ifdef USE_NEON
5769   // Constants.
5770   const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5771   const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5772   const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5773   const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5774 
5775   for (; i <= size - 16; i += 16) {
5776     const uint8x16_t input_vec = vld1q_u8(input_data + i);
5777     const uint16x8_t first_half = vmovl_u8(vget_low_u8(input_vec));
5778     const uint16x8_t second_half = vmovl_u8(vget_high_u8(input_vec));
5779     int32x4x4_t input;
5780     input.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(first_half)));
5781     input.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(first_half)));
5782     input.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(second_half)));
5783     input.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(second_half)));
5784     input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5785     input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5786     input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5787     input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5788 
5789     int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5790         input, effective_scale_multiplier, effective_scale_shift);
5791 
5792     result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5793     result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5794     result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5795     result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5796     result.val[0] =
5797         vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5798     result.val[1] =
5799         vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5800     result.val[2] =
5801         vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5802     result.val[3] =
5803         vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5804 
5805     const int16x4_t narrowed_val_1 = vqmovn_s32(result.val[0]);
5806     const int16x4_t narrowed_val_2 = vqmovn_s32(result.val[1]);
5807     const int16x4_t narrowed_val_3 = vqmovn_s32(result.val[2]);
5808     const int16x4_t narrowed_val_4 = vqmovn_s32(result.val[3]);
5809     const int16x8_t output_first_half =
5810         vcombine_s16(narrowed_val_1, narrowed_val_2);
5811     const int16x8_t output_second_half =
5812         vcombine_s16(narrowed_val_3, narrowed_val_4);
5813     const int8x8_t narrowed_first_half = vqmovn_s16(output_first_half);
5814     const int8x8_t narrowed_second_half = vqmovn_s16(output_second_half);
5815     const int8x16_t narrowed_result =
5816         vcombine_s8(narrowed_first_half, narrowed_second_half);
5817     vst1q_s8(output_data + i, narrowed_result);
5818   }
5819 
5820 #endif
5821   for (; i < size; ++i) {
5822     const int32_t input = input_data[i] - input_zeropoint;
5823     const int32_t output =
5824         MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5825                                       effective_scale_shift) +
5826         output_zeropoint;
5827     const int32_t clamped_output =
5828         std::max(std::min(output, kMaxOutput), kMinOutput);
5829     output_data[i] = static_cast<int8_t>(clamped_output);
5830   }
5831 }
5832 
5833 template <>
5834 inline void Requantize<int8_t, int8_t>(const int8_t* input_data, int32_t size,
5835                                        int32_t effective_scale_multiplier,
5836                                        int32_t effective_scale_shift,
5837                                        int32_t input_zeropoint,
5838                                        int32_t output_zeropoint,
5839                                        int8_t* output_data) {
5840   ruy::profiler::ScopeLabel label("Requantize/Int8ToInt8");
5841 
5842   static constexpr int32_t kMinOutput = std::numeric_limits<int8_t>::min();
5843   static constexpr int32_t kMaxOutput = std::numeric_limits<int8_t>::max();
5844 
5845   int i = 0;
5846 #ifdef USE_NEON
5847   // Constants.
5848   const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5849   const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5850   const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5851   const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5852 
5853   for (; i <= size - 16; i += 16) {
5854     const int8x16_t input_vec = vld1q_s8(input_data + i);
5855     const int16x8_t first_half = vmovl_s8(vget_low_s8(input_vec));
5856     const int16x8_t second_half = vmovl_s8(vget_high_s8(input_vec));
5857     int32x4x4_t input;
5858     input.val[0] = vmovl_s16(vget_low_s16(first_half));
5859     input.val[1] = vmovl_s16(vget_high_s16(first_half));
5860     input.val[2] = vmovl_s16(vget_low_s16(second_half));
5861     input.val[3] = vmovl_s16(vget_high_s16(second_half));
5862 
5863     input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5864     input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5865     input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5866     input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5867 
5868     int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5869         input, effective_scale_multiplier, effective_scale_shift);
5870 
5871     result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5872     result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5873     result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5874     result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5875     result.val[0] =
5876         vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5877     result.val[1] =
5878         vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5879     result.val[2] =
5880         vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5881     result.val[3] =
5882         vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5883 
5884     const int16x4_t narrowed_val_1 = vqmovn_s32(result.val[0]);
5885     const int16x4_t narrowed_val_2 = vqmovn_s32(result.val[1]);
5886     const int16x4_t narrowed_val_3 = vqmovn_s32(result.val[2]);
5887     const int16x4_t narrowed_val_4 = vqmovn_s32(result.val[3]);
5888     const int16x8_t output_first_half =
5889         vcombine_s16(narrowed_val_1, narrowed_val_2);
5890     const int16x8_t output_second_half =
5891         vcombine_s16(narrowed_val_3, narrowed_val_4);
5892     const int8x8_t narrowed_first_half = vqmovn_s16(output_first_half);
5893     const int8x8_t narrowed_second_half = vqmovn_s16(output_second_half);
5894     const int8x16_t narrowed_result =
5895         vcombine_s8(narrowed_first_half, narrowed_second_half);
5896     vst1q_s8(output_data + i, narrowed_result);
5897   }
5898 
5899 #endif
5900   for (; i < size; ++i) {
5901     const int32_t input = input_data[i] - input_zeropoint;
5902     const int32_t output =
5903         MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5904                                       effective_scale_shift) +
5905         output_zeropoint;
5906     const int32_t clamped_output =
5907         std::max(std::min(output, kMaxOutput), kMinOutput);
5908     output_data[i] = static_cast<int8_t>(clamped_output);
5909   }
5910 }
5911 
5912 template <>
5913 inline void Requantize<uint8_t, uint8_t>(
5914     const uint8_t* input_data, int32_t size, int32_t effective_scale_multiplier,
5915     int32_t effective_scale_shift, int32_t input_zeropoint,
5916     int32_t output_zeropoint, uint8_t* output_data) {
5917   ruy::profiler::ScopeLabel label("Requantize/Uint8ToUint8");
5918 
5919   static constexpr int32_t kMinOutput = std::numeric_limits<uint8_t>::min();
5920   static constexpr int32_t kMaxOutput = std::numeric_limits<uint8_t>::max();
5921 
5922   int i = 0;
5923 #ifdef USE_NEON
5924   // Constants.
5925   const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5926   const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5927   const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5928   const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5929 
5930   for (; i <= size - 16; i += 16) {
5931     const uint8x16_t input_vec = vld1q_u8(input_data + i);
5932     const uint16x8_t first_half = vmovl_u8(vget_low_u8(input_vec));
5933     const uint16x8_t second_half = vmovl_u8(vget_high_u8(input_vec));
5934     int32x4x4_t input;
5935     input.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(first_half)));
5936     input.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(first_half)));
5937     input.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(second_half)));
5938     input.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(second_half)));
5939     input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5940     input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5941     input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5942     input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5943 
5944     int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5945         input, effective_scale_multiplier, effective_scale_shift);
5946 
5947     result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5948     result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5949     result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5950     result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5951     result.val[0] =
5952         vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5953     result.val[1] =
5954         vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5955     result.val[2] =
5956         vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5957     result.val[3] =
5958         vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5959 
5960     const uint32x4_t result_val_1_unsigned =
5961         vreinterpretq_u32_s32(result.val[0]);
5962     const uint32x4_t result_val_2_unsigned =
5963         vreinterpretq_u32_s32(result.val[1]);
5964     const uint32x4_t result_val_3_unsigned =
5965         vreinterpretq_u32_s32(result.val[2]);
5966     const uint32x4_t result_val_4_unsigned =
5967         vreinterpretq_u32_s32(result.val[3]);
5968 
5969     const uint16x4_t narrowed_val_1 = vqmovn_u32(result_val_1_unsigned);
5970     const uint16x4_t narrowed_val_2 = vqmovn_u32(result_val_2_unsigned);
5971     const uint16x4_t narrowed_val_3 = vqmovn_u32(result_val_3_unsigned);
5972     const uint16x4_t narrowed_val_4 = vqmovn_u32(result_val_4_unsigned);
5973     const uint16x8_t output_first_half =
5974         vcombine_u16(narrowed_val_1, narrowed_val_2);
5975     const uint16x8_t output_second_half =
5976         vcombine_u16(narrowed_val_3, narrowed_val_4);
5977     const uint8x8_t narrowed_first_half = vqmovn_u16(output_first_half);
5978     const uint8x8_t narrowed_second_half = vqmovn_u16(output_second_half);
5979     const uint8x16_t narrowed_result =
5980         vcombine_u8(narrowed_first_half, narrowed_second_half);
5981     vst1q_u8(output_data + i, narrowed_result);
5982   }
5983 
5984 #endif
5985   for (; i < size; ++i) {
5986     const int32_t input = input_data[i] - input_zeropoint;
5987     const int32_t output =
5988         MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5989                                       effective_scale_shift) +
5990         output_zeropoint;
5991     const int32_t clamped_output =
5992         std::max(std::min(output, kMaxOutput), kMinOutput);
5993     output_data[i] = static_cast<uint8_t>(clamped_output);
5994   }
5995 }
5996 
5997 inline void HardSwish(const RuntimeShape& input_shape, const float* input_data,
5998                       const RuntimeShape& output_shape, float* output_data) {
5999   ruy::profiler::ScopeLabel label("HardSwish/Float");
6000   auto size = MatchingFlatSize(input_shape, output_shape);
6001   int i = 0;
6002 #ifdef USE_NEON
6003   const float32x4_t zero = vdupq_n_f32(0.0f);
6004   const float32x4_t three = vdupq_n_f32(3.0f);
6005   const float32x4_t six = vdupq_n_f32(6.0f);
6006   const float32x4_t one_sixth = vdupq_n_f32(1.0f / 6.0f);
6007 
6008   for (; i <= size - 16; i += 16) {
6009     // 4x partially unrolled version of the loop below. Refer to its comments.
6010     const float32x4_t in_0 = vld1q_f32(input_data + i + 0);
6011     const float32x4_t in_1 = vld1q_f32(input_data + i + 4);
6012     const float32x4_t in_2 = vld1q_f32(input_data + i + 8);
6013     const float32x4_t in_3 = vld1q_f32(input_data + i + 12);
6014     const float32x4_t in_scaled_0 = vmulq_f32(in_0, one_sixth);
6015     const float32x4_t in_scaled_1 = vmulq_f32(in_1, one_sixth);
6016     const float32x4_t in_scaled_2 = vmulq_f32(in_2, one_sixth);
6017     const float32x4_t in_scaled_3 = vmulq_f32(in_3, one_sixth);
6018     const float32x4_t in_reluish_0 =
6019         vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_0, three)));
6020     const float32x4_t in_reluish_1 =
6021         vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_1, three)));
6022     const float32x4_t in_reluish_2 =
6023         vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_2, three)));
6024     const float32x4_t in_reluish_3 =
6025         vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_3, three)));
6026     const float32x4_t product_0 = vmulq_f32(in_scaled_0, in_reluish_0);
6027     const float32x4_t product_1 = vmulq_f32(in_scaled_1, in_reluish_1);
6028     const float32x4_t product_2 = vmulq_f32(in_scaled_2, in_reluish_2);
6029     const float32x4_t product_3 = vmulq_f32(in_scaled_3, in_reluish_3);
6030     vst1q_f32(output_data + i + 0, product_0);
6031     vst1q_f32(output_data + i + 4, product_1);
6032     vst1q_f32(output_data + i + 8, product_2);
6033     vst1q_f32(output_data + i + 12, product_3);
6034   }
6035   for (; i <= size - 4; i += 4) {
6036     // The expression to be computed is:
6037     //   out = one_sixth * in * min(six, max(zero, (in + three)))
6038     // We structure the AST to have two roughly balanced, independent branches:
6039     //  - Multiplication: in_scaled = one_sixth * in.
6040     //  - Addition and clamping: in_reluish = min(six, max(zero, (in + three))).
6041     // Then the remaining multiplication at the root of the tree.
6042     const float32x4_t in = vld1q_f32(input_data + i);
6043     const float32x4_t in_scaled = vmulq_f32(in, one_sixth);
6044     const float32x4_t in_reluish =
6045         vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in, three)));
6046     const float32x4_t product = vmulq_f32(in_scaled, in_reluish);
6047     vst1q_f32(output_data + i, product);
6048   }
6049 #endif
6050   for (; i < size; i++) {
6051     const float in = input_data[i];
6052     output_data[i] =
6053         in * std::min(6.0f, std::max(0.0f, in + 3.0f)) * (1.0f / 6.0f);
6054   }
6055 }
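// Spot values of the hard-swish formula above, out = x * relu6(x + 3) / 6:
//   x = -4 -> 0,    x = -1 -> -1 * 2 / 6 = -0.333...,
//   x =  1 -> 1 * 4 / 6 = 0.666...,    x = 4 -> 4.
// The function is identically zero for x <= -3 and the identity for x >= 3.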
6056 
6057 #ifdef USE_NEON
6058 inline void SaturateAndStore(int16x8_t src, std::uint8_t* dst) {
6059   // Narrow values down to 8 bit unsigned, saturating.
6060   uint8x8_t res8 = vqmovun_s16(src);
6061   // Store results to destination.
6062   vst1_u8(dst, res8);
6063 }
6064 
6065 inline void SaturateAndStore(int16x8_t src, std::int8_t* dst) {
6066   // Narrow values down to 8 bit signed, saturating.
6067   int8x8_t res8 = vqmovn_s16(src);
6068   // Store results to destination.
6069   vst1_s8(dst, res8);
6070 }
6071 #endif
6072 
6073 template <typename T>
6074 inline void HardSwish(const HardSwishParams& params,
6075                       const RuntimeShape& input_shape, const T* input_data,
6076                       const RuntimeShape& output_shape, T* output_data) {
6077   ruy::profiler::ScopeLabel label("HardSwish/Quantized");
6078 
6079   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6080 
6081   int i = 0;
6082   // This code heavily uses NEON saturating left shifts (vqshl*) with shift
6083   // amounts that can be zero, in which case we rely on the correct behavior
6084   // of a left shift by zero returning just its first operand unmodified.
6085   // Unfortunately, the Intel arm_neon_sse.h implementation of vqshl* is
6086   // buggy in the case of zero shift amounts, see b/137199585. That is why
6087   // this NEON code path is restricted to true ARM NEON, excluding
6088   // arm_neon_sse.h. Anyway, the arm_neon_sse.h implementation of saturating
6089   // left shifts is slow scalar code, so there may not be much benefit in
6090   // running that over just plain reference code.
6091   //
6092   // TODO(b/137199585): revisit when this is fixed.
6093 #ifdef __ARM_NEON
6094   const int16x8_t positive_reluish_multiplier_exponent_minus_one =
6095       vdupq_n_s16(std::max(0, params.reluish_multiplier_exponent - 1));
6096   const int16x8_t positive_reluish_multiplier_exponent_last_bit =
6097       vdupq_n_s16(params.reluish_multiplier_exponent > 0 ? 1 : 0);
6098   const int16x8_t negative_reluish_multiplier_exponent =
6099       vdupq_n_s16(std::min(0, params.reluish_multiplier_exponent));
6100   const int16x8_t constant_32767 = vdupq_n_s16(32767);
6101   const int16x8_t output_multiplier_exponent =
6102       vdupq_n_s16(params.output_multiplier_exponent);
6103   const int16x8_t output_zero_point = vdupq_n_s16(params.output_zero_point);
6104   // 4x unrolled version of the below NEON loop. Read that first.
6105   for (; i <= flat_size - 32; i += 32) {
6106     using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6107     const int16x8x2_t input_value_0_1 =
6108         Load16AndSubtractZeroPoint(input_data + i, params.input_zero_point);
6109     const int16x8x2_t input_value_2_3 = Load16AndSubtractZeroPoint(
6110         input_data + i + 16, params.input_zero_point);
6111     const int16x8_t input_value_on_hires_input_scale_0 =
6112         vshlq_n_s16(input_value_0_1.val[0], 7);
6113     const int16x8_t input_value_on_hires_input_scale_1 =
6114         vshlq_n_s16(input_value_0_1.val[1], 7);
6115     const int16x8_t input_value_on_hires_input_scale_2 =
6116         vshlq_n_s16(input_value_2_3.val[0], 7);
6117     const int16x8_t input_value_on_hires_input_scale_3 =
6118         vshlq_n_s16(input_value_2_3.val[1], 7);
6119     const int16x8_t input_value_on_preshift_output_scale_0 =
6120         vqrdmulhq_n_s16(input_value_on_hires_input_scale_0,
6121                         params.output_multiplier_fixedpoint_int16);
6122     const int16x8_t input_value_on_preshift_output_scale_1 =
6123         vqrdmulhq_n_s16(input_value_on_hires_input_scale_1,
6124                         params.output_multiplier_fixedpoint_int16);
6125     const int16x8_t input_value_on_preshift_output_scale_2 =
6126         vqrdmulhq_n_s16(input_value_on_hires_input_scale_2,
6127                         params.output_multiplier_fixedpoint_int16);
6128     const int16x8_t input_value_on_preshift_output_scale_3 =
6129         vqrdmulhq_n_s16(input_value_on_hires_input_scale_3,
6130                         params.output_multiplier_fixedpoint_int16);
6131     int16x8_t reluish_value_0 = input_value_on_hires_input_scale_0;
6132     int16x8_t reluish_value_1 = input_value_on_hires_input_scale_1;
6133     int16x8_t reluish_value_2 = input_value_on_hires_input_scale_2;
6134     int16x8_t reluish_value_3 = input_value_on_hires_input_scale_3;
6135     reluish_value_0 = vqshlq_s16(
6136         reluish_value_0, positive_reluish_multiplier_exponent_minus_one);
6137     reluish_value_1 = vqshlq_s16(
6138         reluish_value_1, positive_reluish_multiplier_exponent_minus_one);
6139     reluish_value_2 = vqshlq_s16(
6140         reluish_value_2, positive_reluish_multiplier_exponent_minus_one);
6141     reluish_value_3 = vqshlq_s16(
6142         reluish_value_3, positive_reluish_multiplier_exponent_minus_one);
6143     reluish_value_0 = vqrdmulhq_n_s16(
6144         reluish_value_0, params.reluish_multiplier_fixedpoint_int16);
6145     reluish_value_1 = vqrdmulhq_n_s16(
6146         reluish_value_1, params.reluish_multiplier_fixedpoint_int16);
6147     reluish_value_2 = vqrdmulhq_n_s16(
6148         reluish_value_2, params.reluish_multiplier_fixedpoint_int16);
6149     reluish_value_3 = vqrdmulhq_n_s16(
6150         reluish_value_3, params.reluish_multiplier_fixedpoint_int16);
6151     reluish_value_0 = vqshlq_s16(reluish_value_0,
6152                                  positive_reluish_multiplier_exponent_last_bit);
6153     reluish_value_1 = vqshlq_s16(reluish_value_1,
6154                                  positive_reluish_multiplier_exponent_last_bit);
6155     reluish_value_2 = vqshlq_s16(reluish_value_2,
6156                                  positive_reluish_multiplier_exponent_last_bit);
6157     reluish_value_3 = vqshlq_s16(reluish_value_3,
6158                                  positive_reluish_multiplier_exponent_last_bit);
6159     reluish_value_0 =
6160         vrshlq_s16(reluish_value_0, negative_reluish_multiplier_exponent);
6161     reluish_value_1 =
6162         vrshlq_s16(reluish_value_1, negative_reluish_multiplier_exponent);
6163     reluish_value_2 =
6164         vrshlq_s16(reluish_value_2, negative_reluish_multiplier_exponent);
6165     reluish_value_3 =
6166         vrshlq_s16(reluish_value_3, negative_reluish_multiplier_exponent);
6167     reluish_value_0 = vrhaddq_s16(reluish_value_0, constant_32767);
6168     reluish_value_1 = vrhaddq_s16(reluish_value_1, constant_32767);
6169     reluish_value_2 = vrhaddq_s16(reluish_value_2, constant_32767);
6170     reluish_value_3 = vrhaddq_s16(reluish_value_3, constant_32767);
6171     const int16x8_t preshift_output_value_0 =
6172         vqdmulhq_s16(reluish_value_0, input_value_on_preshift_output_scale_0);
6173     const int16x8_t preshift_output_value_1 =
6174         vqdmulhq_s16(reluish_value_1, input_value_on_preshift_output_scale_1);
6175     const int16x8_t preshift_output_value_2 =
6176         vqdmulhq_s16(reluish_value_2, input_value_on_preshift_output_scale_2);
6177     const int16x8_t preshift_output_value_3 =
6178         vqdmulhq_s16(reluish_value_3, input_value_on_preshift_output_scale_3);
6179     int16x8_t output_value_0 =
6180         vrshlq_s16(preshift_output_value_0, output_multiplier_exponent);
6181     int16x8_t output_value_1 =
6182         vrshlq_s16(preshift_output_value_1, output_multiplier_exponent);
6183     int16x8_t output_value_2 =
6184         vrshlq_s16(preshift_output_value_2, output_multiplier_exponent);
6185     int16x8_t output_value_3 =
6186         vrshlq_s16(preshift_output_value_3, output_multiplier_exponent);
6187     output_value_0 = vaddq_s16(output_value_0, output_zero_point);
6188     output_value_1 = vaddq_s16(output_value_1, output_zero_point);
6189     output_value_2 = vaddq_s16(output_value_2, output_zero_point);
6190     output_value_3 = vaddq_s16(output_value_3, output_zero_point);
6191     SaturateAndStore(output_value_0, output_data + i);
6192     SaturateAndStore(output_value_1, output_data + i + 8);
6193     SaturateAndStore(output_value_2, output_data + i + 16);
6194     SaturateAndStore(output_value_3, output_data + i + 24);
6195   }
6196   // NEON version of reference_ops::HardSwish. Read that first.
6197   for (; i <= flat_size - 8; i += 8) {
6198     using cpu_backend_gemm::detail::Load8AndSubtractZeroPoint;
6199     const int16x8_t input_value =
6200         Load8AndSubtractZeroPoint(input_data + i, params.input_zero_point);
6201     const int16x8_t input_value_on_hires_input_scale =
6202         vshlq_n_s16(input_value, 7);
6203     const int16x8_t input_value_on_preshift_output_scale =
6204         vqrdmulhq_n_s16(input_value_on_hires_input_scale,
6205                         params.output_multiplier_fixedpoint_int16);
6206     int16x8_t reluish_value = input_value_on_hires_input_scale;
6207     reluish_value = vqshlq_s16(reluish_value,
6208                                positive_reluish_multiplier_exponent_minus_one);
6209     reluish_value = vqrdmulhq_n_s16(reluish_value,
6210                                     params.reluish_multiplier_fixedpoint_int16);
6211     reluish_value = vqshlq_s16(reluish_value,
6212                                positive_reluish_multiplier_exponent_last_bit);
6213     reluish_value =
6214         vrshlq_s16(reluish_value, negative_reluish_multiplier_exponent);
6215     reluish_value = vrhaddq_s16(reluish_value, constant_32767);
6216     const int16x8_t preshift_output_value =
6217         vqdmulhq_s16(reluish_value, input_value_on_preshift_output_scale);
6218     int16x8_t output_value =
6219         vrshlq_s16(preshift_output_value, output_multiplier_exponent);
6220     output_value = vaddq_s16(output_value, output_zero_point);
6221     SaturateAndStore(output_value, output_data + i);
6222   }
6223 #endif
6224   // TODO(b/137208495): revisit when unit tests cover reference code.
6225   // Fall back to reference_ops::HardSwish. In general we have preferred
6226   // to duplicate such scalar code rather than call reference code to handle
6227   // leftovers, thinking that code duplication was not a big concern.
6228   // However, most of our unit tests happen to test only optimized code,
6229   // and the quantized HardSwish implementation is nontrivial enough that
6230   // I really want test coverage for the reference code.
6231   if (i < flat_size) {
6232     const RuntimeShape leftover_shape{flat_size - i};
6233     reference_ops::HardSwish(params, leftover_shape, input_data + i,
6234                              leftover_shape, output_data + i);
6235   }
6236 }
6237 
6238 template <typename T>
6239 inline void IntegerExponentPow(const ArithmeticParams& params,
6240                                const RuntimeShape& unextended_base_shape,
6241                                const T* base_data, const int exponent,
6242                                const RuntimeShape& unextended_output_shape,
6243                                T* output_data) {
6244   TFLITE_DCHECK_GE(exponent, 1);
6245   if (exponent == 1) {
6246     // copy data over.
6247     std::memcpy(output_data, base_data,
6248                 unextended_base_shape.FlatSize() * sizeof(T));
6249   } else {
6250     IntegerExponentPow(params, unextended_base_shape, base_data, exponent / 2,
6251                        unextended_output_shape, output_data);
6252     Mul(params, unextended_base_shape, output_data, unextended_base_shape,
6253         output_data, unextended_output_shape, output_data);
6254     if (exponent % 2 == 1) {
6255       Mul(params, unextended_base_shape, base_data, unextended_base_shape,
6256           output_data, unextended_output_shape, output_data);
6257     }
6258   }
6259 }
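// Exponentiation-by-squaring trace for the recursion above, with exponent = 5:
//   IntegerExponentPow(base, 2)  -> base^2  (exponent = 1 copies, then Mul)
//   Mul(output, output)          -> base^4
//   exponent % 2 == 1            -> Mul(base, output) = base^5
// so only O(log2(exponent)) elementwise Mul passes are needed.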
6260 
6261 template <typename T>
6262 inline void BroadcastPow4D(const RuntimeShape& unextended_input1_shape,
6263                            const T* input1_data,
6264                            const RuntimeShape& unextended_input2_shape,
6265                            const T* input2_data,
6266                            const RuntimeShape& unextended_output_shape,
6267                            T* output_data) {
6268   ruy::profiler::ScopeLabel label("PowBroadcast");
6269 
6270   if (unextended_input2_shape.FlatSize() == 1) {
6271     static const float epsilon = 1e-5;
6272     const T exponent = input2_data[0];
6273     const int int_exponent = static_cast<int>(std::round(exponent));
6274     if ((std::abs(input2_data[0] - int_exponent) < epsilon) &&
6275         (int_exponent >= 1)) {
6276       ArithmeticParams params;
6277       if (std::is_same<T, float>::value) {
6278         params.float_activation_max = std::numeric_limits<float>::max();
6279         params.float_activation_min = std::numeric_limits<float>::lowest();
6280       } else if (std::is_same<T, int>::value) {
6281         params.quantized_activation_max = std::numeric_limits<int>::max();
6282         params.quantized_activation_min = std::numeric_limits<int>::lowest();
6283       }
6284       IntegerExponentPow(params, unextended_input1_shape, input1_data,
6285                          int_exponent, unextended_output_shape, output_data);
6286       return;
6287     }
6288   }
6289   reference_ops::BroadcastPow4DSlow(unextended_input1_shape, input1_data,
6290                                     unextended_input2_shape, input2_data,
6291                                     unextended_output_shape, output_data);
6292 }
6293 
6294 #ifdef USE_NEON
6295 
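// Dequantizes four int32 lanes to float as (input - zero_point) * scale,
// rewritten as input * scale + (-zero_point * scale) so the zero-point term can
// be precomputed once ('zero_times_scale_dup') and folded into a single FMA
// when __ARM_FEATURE_FMA is available.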
6296 inline void ScaleWithNewZeroPoint(const int32x4_t input,
6297                                   const float32x4_t scale_dup,
6298                                   const float32x4_t zero_times_scale_dup,
6299                                   float32x4_t* output) {
6300 #ifdef __ARM_FEATURE_FMA
6301   *output = vfmaq_f32(zero_times_scale_dup, vcvtq_f32_s32(input), scale_dup);
6302 #else
6303   *output = vaddq_f32(vmulq_f32(vcvtq_f32_s32(input), scale_dup),
6304                       zero_times_scale_dup);
6305 #endif
6306 }
6307 
6308 #endif  // USE_NEON
6309 
6310 inline void Dequantize(const tflite::DequantizationParams& op_params,
6311                        const RuntimeShape& input_shape,
6312                        const uint8_t* input_data,
6313                        const RuntimeShape& output_shape, float* output_data) {
6314   ruy::profiler::ScopeLabel label("Dequantize/Uint8");
6315   const int32 zero_point = op_params.zero_point;
6316   const double scale = op_params.scale;
6317   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6318 
6319   int i = 0;
6320 #ifdef USE_NEON
6321   const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6322   const float32x4_t zero_times_scale_dup =
6323       vdupq_n_f32(static_cast<float>(-zero_point * scale));
6324   for (; i <= flat_size - 8; i += 8) {
6325     const uint8x8_t input_u8 = vld1_u8(input_data + i);
6326     const uint16x8_t input_u16 = vmovl_u8(input_u8);
6327     const int16x8_t input_s16 = vreinterpretq_s16_u16(input_u16);
6328     const int16x4_t input_s16_low = vget_low_s16(input_s16);
6329     const int16x4_t input_s16_high = vget_high_s16(input_s16);
6330     const int32x4_t val_low = vmovl_s16(input_s16_low);
6331     const int32x4_t val_high = vmovl_s16(input_s16_high);
6332 
6333     float32x4_t result_low, result_high;
6334     ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6335                           &result_low);
6336     ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6337                           &result_high);
6338 
6339     vst1q_f32(output_data + i, result_low);
6340     vst1q_f32(output_data + i + 4, result_high);
6341   }
6342 #endif  // NEON
6343   for (; i < flat_size; ++i) {
6344     const int32 val = input_data[i];
6345     const float result = static_cast<float>(scale * (val - zero_point));
6346     output_data[i] = result;
6347   }
6348 }
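
// Illustrative usage sketch (assumed buffer sizes and values, not taken from
// the original source): dequantize four uint8 values with scale 0.5 and zero
// point 128. The scalar tail above computes scale * (val - zero_point).
//
//   tflite::DequantizationParams op_params;
//   op_params.zero_point = 128;
//   op_params.scale = 0.5;
//   const uint8_t input[4] = {126, 127, 128, 130};
//   float output[4];
//   Dequantize(op_params, RuntimeShape({1, 1, 1, 4}), input,
//              RuntimeShape({1, 1, 1, 4}), output);
//   // output is now {-1.0f, -0.5f, 0.0f, 1.0f}.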
6349 
6350 inline void Dequantize(const tflite::DequantizationParams& op_params,
6351                        const RuntimeShape& input_shape,
6352                        const int8_t* input_data,
6353                        const RuntimeShape& output_shape, float* output_data) {
6354   ruy::profiler::ScopeLabel label("Dequantize/Int8");
6355   const int32 zero_point = op_params.zero_point;
6356   const double scale = op_params.scale;
6357   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6358 
6359   int i = 0;
6360 #ifdef USE_NEON
6361   const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6362   const float32x4_t zero_times_scale_dup =
6363       vdupq_n_f32(static_cast<float>(-zero_point * scale));
6364   for (; i <= flat_size - 8; i += 8) {
6365     const int8x8_t input_s8 = vld1_s8(input_data + i);
6366     const int16x8_t input_s16 = vmovl_s8(input_s8);
6367     const int16x4_t input_s16_low = vget_low_s16(input_s16);
6368     const int16x4_t input_s16_high = vget_high_s16(input_s16);
6369     const int32x4_t val_low = vmovl_s16(input_s16_low);
6370     const int32x4_t val_high = vmovl_s16(input_s16_high);
6371 
6372     float32x4_t result_low, result_high;
6373     ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6374                           &result_low);
6375     ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6376                           &result_high);
6377 
6378     vst1q_f32(output_data + i, result_low);
6379     vst1q_f32(output_data + i + 4, result_high);
6380   }
6381 #endif  // NEON
6382   for (; i < flat_size; ++i) {
6383     const int32 val = input_data[i];
6384     const float result = static_cast<float>(scale * (val - zero_point));
6385     output_data[i] = result;
6386   }
6387 }
6388 
6389 inline void Dequantize(const tflite::DequantizationParams& op_params,
6390                        const RuntimeShape& input_shape,
6391                        const int16_t* input_data,
6392                        const RuntimeShape& output_shape, float* output_data) {
6393   ruy::profiler::ScopeLabel label("Dequantize/Int16");
6394   const int32 zero_point = op_params.zero_point;
6395   const double scale = op_params.scale;
6396   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6397 
6398   int i = 0;
6399 #ifdef USE_NEON
6400   const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6401   const float32x4_t zero_times_scale_dup =
6402       vdupq_n_f32(static_cast<float>(-zero_point * scale));
6403   for (; i <= flat_size - 8; i += 8) {
6404     const int16x4_t input_s16_low = vld1_s16(input_data + i);
6405     const int16x4_t input_s16_high = vld1_s16(input_data + i + 4);
6406     const int32x4_t val_low = vmovl_s16(input_s16_low);
6407     const int32x4_t val_high = vmovl_s16(input_s16_high);
6408 
6409     float32x4_t result_low, result_high;
6410     ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6411                           &result_low);
6412     ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6413                           &result_high);
6414 
6415     vst1q_f32(output_data + i, result_low);
6416     vst1q_f32(output_data + i + 4, result_high);
6417   }
6418 #endif  // NEON
6419   for (; i < flat_size; ++i) {
6420     const int32 val = input_data[i];
6421     const float result = static_cast<float>(scale * (val - zero_point));
6422     output_data[i] = result;
6423   }
6424 }
6425 
6426 inline void Dequantize(const RuntimeShape& input_shape,
6427                        const Eigen::half* input_data,
6428                        const RuntimeShape& output_shape, float* output_data) {
6429   reference_ops::Dequantize(input_shape, input_data, output_shape, output_data);
6430 }
6431 
6432 template <typename T>
6433 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6434                            const RuntimeShape& input_shape,
6435                            const float* input_data,
6436                            const RuntimeShape& output_shape, T* output_data) {
6437   reference_ops::AffineQuantize(op_params, input_shape, input_data,
6438                                 output_shape, output_data);
6439 }
6440 
6441 template <>
6442 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6443                            const RuntimeShape& input_shape,
6444                            const float* input_data,
6445                            const RuntimeShape& output_shape,
6446                            int8_t* output_data) {
6447   ruy::profiler::ScopeLabel label("Quantize/Int8");
6448   const int32 zero_point = op_params.zero_point;
6449   const double scale = static_cast<double>(op_params.scale);
6450   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6451   static constexpr int32 min_val = std::numeric_limits<int8_t>::min();
6452   static constexpr int32 max_val = std::numeric_limits<int8_t>::max();
6453 
6454   int i = 0;
6455 #ifdef USE_NEON
6456   const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6457   const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6458   const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6459   const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6460 
6461   for (; i <= flat_size - 8; i += 8) {
6462     const float* src_data_ptr = input_data + i;
6463     float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6464     float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6465 
6466     input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6467     input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6468 
6469     int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6470     int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6471 
6472     casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6473     casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6474 
6475     // Clamp the values to fit the target type's range.
6476     casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6477     casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6478     casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6479     casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6480 
6481     const int16x4_t narrowed_val_0 = vmovn_s32(casted_val_0);
6482     const int16x4_t narrowed_val_1 = vmovn_s32(casted_val_1);
6483     const int16x8_t combined_val = vcombine_s16(narrowed_val_0, narrowed_val_1);
6484     const int8x8_t combined_val_narrowed = vmovn_s16(combined_val);
6485     vst1_s8(output_data + i, combined_val_narrowed);
6486   }
6487 #endif  // NEON
6488 
6489   for (; i < flat_size; ++i) {
6490     const float val = input_data[i];
6491     const int32 unclamped =
6492         static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6493     const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6494     output_data[i] = clamped;
6495   }
6496 }
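
// Worked example (not part of the original implementation): the scalar tail
// above computes
//   output = clamp(round(val / scale) + zero_point, -128, 127).
// With scale = 0.5 and zero_point = -1, an input of 2.3f maps to
// round(4.6) + (-1) = 4, while an input of 100.0f maps to 200 - 1 = 199,
// which is clamped to 127. The NEON path performs the same steps on eight
// floats at a time.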
6497 
6498 template <>
6499 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6500                            const RuntimeShape& input_shape,
6501                            const float* input_data,
6502                            const RuntimeShape& output_shape,
6503                            uint8_t* output_data) {
6504   ruy::profiler::ScopeLabel label("Quantize/Uint8");
6505   const int32 zero_point = op_params.zero_point;
6506   const double scale = static_cast<double>(op_params.scale);
6507   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6508   static constexpr int32 min_val = std::numeric_limits<uint8_t>::min();
6509   static constexpr int32 max_val = std::numeric_limits<uint8_t>::max();
6510 
6511   int i = 0;
6512 #ifdef USE_NEON
6513   const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6514   const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6515   const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6516   const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6517 
6518   for (; i <= flat_size - 8; i += 8) {
6519     const float* src_data_ptr = input_data + i;
6520     float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6521     float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6522 
6523     input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6524     input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6525 
6526     int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6527     int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6528 
6529     casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6530     casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6531 
6532     // Clamp the values to fit the target type's range.
6533     casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6534     casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6535     casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6536     casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6537 
6538     const uint16x4_t narrowed_val_0 = vqmovun_s32(casted_val_0);
6539     const uint16x4_t narrowed_val_1 = vqmovun_s32(casted_val_1);
6540     const uint16x8_t combined_val =
6541         vcombine_u16(narrowed_val_0, narrowed_val_1);
6542     const uint8x8_t combined_val_narrowed = vmovn_u16(combined_val);
6543     vst1_u8(output_data + i, combined_val_narrowed);
6544   }
6545 #endif  // NEON
6546 
6547   for (; i < flat_size; ++i) {
6548     const float val = input_data[i];
6549     const int32 unclamped =
6550         static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6551     const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6552     output_data[i] = clamped;
6553   }
6554 }
6555 
6556 template <>
6557 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6558                            const RuntimeShape& input_shape,
6559                            const float* input_data,
6560                            const RuntimeShape& output_shape,
6561                            int16_t* output_data) {
6562   ruy::profiler::ScopeLabel label("Quantize/Int16");
6563   const int32 zero_point = op_params.zero_point;
6564   const double scale = static_cast<double>(op_params.scale);
6565   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6566   static constexpr int32 min_val = std::numeric_limits<int16_t>::min();
6567   static constexpr int32 max_val = std::numeric_limits<int16_t>::max();
6568 
6569   int i = 0;
6570 #ifdef USE_NEON
6571   const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6572   const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6573   const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6574   const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6575 
6576   for (; i <= flat_size - 8; i += 8) {
6577     const float* src_data_ptr = input_data + i;
6578     float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6579     float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6580 
6581     input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6582     input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6583 
6584     int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6585     int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6586 
6587     casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6588     casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6589 
6590     // Clamp the values to fit the target type's range.
6591     casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6592     casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6593     casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6594     casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6595 
6596     const int16x4_t narrowed_val_0 = vmovn_s32(casted_val_0);
6597     const int16x4_t narrowed_val_1 = vmovn_s32(casted_val_1);
6598     vst1_s16(output_data + i, narrowed_val_0);
6599     vst1_s16(output_data + i + 4, narrowed_val_1);
6600   }
6601 #endif  // NEON
6602 
6603   for (; i < flat_size; ++i) {
6604     const float val = input_data[i];
6605     const int32 unclamped =
6606         static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6607     const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6608     output_data[i] = clamped;
6609   }
6610 }
6611 
6612 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6613 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6614 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6615 #ifdef GEMMLOWP_NEON
6616 
6617 inline int16x8x4_t SaturatingRounding(
6618     int16x8_t input_val_0, int16x8_t input_val_1, int16x8_t input_val_2,
6619     int16x8_t input_val_3, int input_left_shift, int input_multiplier) {
6620   // This performs what is expressed in the scalar code as
6621   // const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
6622   //      static_cast<int16>(input_val_centered * (1 << input_left_shift)),
6623   //      static_cast<int16>(input_multiplier));
6624   const int16x8_t left_shift_dup = vdupq_n_s16(input_left_shift);
6625   const int16x8_t input_val_shifted_0 = vshlq_s16(input_val_0, left_shift_dup);
6626   const int16x8_t input_val_shifted_1 = vshlq_s16(input_val_1, left_shift_dup);
6627   const int16x8_t input_val_shifted_2 = vshlq_s16(input_val_2, left_shift_dup);
6628   const int16x8_t input_val_shifted_3 = vshlq_s16(input_val_3, left_shift_dup);
6629   int16x8x4_t result;
6630   result.val[0] = vqrdmulhq_n_s16(input_val_shifted_0, input_multiplier);
6631   result.val[1] = vqrdmulhq_n_s16(input_val_shifted_1, input_multiplier);
6632   result.val[2] = vqrdmulhq_n_s16(input_val_shifted_2, input_multiplier);
6633   result.val[3] = vqrdmulhq_n_s16(input_val_shifted_3, input_multiplier);
6634   return result;
6635 }
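
// Worked example (not part of the original implementation): vqrdmulhq_n_s16
// computes, per int16 lane, sat((2 * a * b + (1 << 15)) >> 16), matching
// SaturatingRoundingDoublingHighMul in the scalar code. For instance, with
// input_val = 1000, input_left_shift = 3 and input_multiplier = 16384
// (0.5 in Q15), the shifted value is 8000 and the result is
// (2 * 8000 * 16384 + 32768) >> 16 = 4000, i.e. roughly 8000 * 0.5.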
6636 
6637 // 4-bit fixed point is enough for tanh since tanh(16) is nearly equal to
6638 // one, considering 7 digits below zero.
6639 inline int16x8x4_t FixedPoint4Logistic(int16x8x4_t input_val) {
6640   // Invoke gemmlowp::logistic on FixedPoint wrapping int16x8_t
6641   using FixedPoint4 = gemmlowp::FixedPoint<int16x8_t, 4>;
6642   using FixedPoint0 = gemmlowp::FixedPoint<int16x8_t, 0>;
6643   const FixedPoint4 input_val_f4_0 = FixedPoint4::FromRaw(input_val.val[0]);
6644   const FixedPoint4 input_val_f4_1 = FixedPoint4::FromRaw(input_val.val[1]);
6645   const FixedPoint4 input_val_f4_2 = FixedPoint4::FromRaw(input_val.val[2]);
6646   const FixedPoint4 input_val_f4_3 = FixedPoint4::FromRaw(input_val.val[3]);
6647 
6648   // TODO(b/134622898) Implement a low accuracy version of logistic. In this
6649   // method, gemmlowp::logistic spends about 80% of the execution time. The
6650   // current implementation is roughly 12-bit accurate in the 16-bit fixed
6651   // point case. Until the error bounds are reached, there is room for
6652   // improvement.
6653   const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0);
6654   const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1);
6655   const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2);
6656   const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3);
6657 
6658   // Divide by 2^7 as in the scalar code
6659   int16x8x4_t result;
6660   result.val[0] = vrshrq_n_s16(output_val_f0_0.raw(), 7);
6661   result.val[1] = vrshrq_n_s16(output_val_f0_1.raw(), 7);
6662   result.val[2] = vrshrq_n_s16(output_val_f0_2.raw(), 7);
6663   result.val[3] = vrshrq_n_s16(output_val_f0_3.raw(), 7);
6664   return result;
6665 }
6666 
6667 // 4-bit fixed point is enough for tanh since tanh(16) is nearly equal to
6668 // one, considering at least 11 digits below zero.
6669 inline int16x8x4_t FixedPoint4Tanh(int16x8x4_t input_val) {
6670   // Invoke gemmlowp::tanh on FixedPoint wrapping int16x8_t
6671   using FixedPoint4 = gemmlowp::FixedPoint<int16x8_t, 4>;
6672   using FixedPoint0 = gemmlowp::FixedPoint<int16x8_t, 0>;
6673   const FixedPoint4 input_val_f4_0 = FixedPoint4::FromRaw(input_val.val[0]);
6674   const FixedPoint4 input_val_f4_1 = FixedPoint4::FromRaw(input_val.val[1]);
6675   const FixedPoint4 input_val_f4_2 = FixedPoint4::FromRaw(input_val.val[2]);
6676   const FixedPoint4 input_val_f4_3 = FixedPoint4::FromRaw(input_val.val[3]);
6677 
6678   // TODO(b/134622898) Implement a low accuracy version of tanh. In this
6679   // method, gemmlowp::tanh spends about 80% of the execution time. The
6680   // current implementation is roughly 12-bit accurate in the 16-bit fixed
6681   // point case. Until the error bounds are reached, there is room for
6682   // improvement.
6683   const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0);
6684   const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1);
6685   const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2);
6686   const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3);
6687 
6688   // Divide by 2^8 as in the scalar code
6689   int16x8x4_t result;
6690   result.val[0] = vrshrq_n_s16(output_val_f0_0.raw(), 8);
6691   result.val[1] = vrshrq_n_s16(output_val_f0_1.raw(), 8);
6692   result.val[2] = vrshrq_n_s16(output_val_f0_2.raw(), 8);
6693   result.val[3] = vrshrq_n_s16(output_val_f0_3.raw(), 8);
6694   return result;
6695 }
6696 
6697 inline uint8x16x2_t CalculateUnsignedClampingWithRangeBitMasks(
6698     int16x8x2_t input_val, int16x8_t range_radius_dup,
6699     int16x8_t neg_range_radius_dup) {
6700   const uint16x8_t mask_rightclamp_0 =
6701       vcgtq_s16(input_val.val[0], range_radius_dup);
6702   const uint16x8_t mask_rightclamp_1 =
6703       vcgtq_s16(input_val.val[1], range_radius_dup);
6704 
6705   const uint16x8_t mask_leftclamp_0 =
6706       vcgeq_s16(input_val.val[0], neg_range_radius_dup);
6707   const uint16x8_t mask_leftclamp_1 =
6708       vcgeq_s16(input_val.val[1], neg_range_radius_dup);
6709 
6710   uint8x16x2_t result;
6711   result.val[0] = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
6712                               vshrn_n_u16(mask_leftclamp_1, 8));
6713   result.val[1] = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
6714                               vshrn_n_u16(mask_rightclamp_1, 8));
6715   return result;
6716 }
6717 
6718 inline uint8x16x2_t CalculateSignedClampingWithRangeBitMasks(
6719     int16x8x2_t input_val, int16x8_t range_radius_dup,
6720     int16x8_t neg_range_radius_dup) {
6721   const uint16x8_t mask_rightclamp_0 =
6722       vcgtq_s16(input_val.val[0], range_radius_dup);
6723   const uint16x8_t mask_rightclamp_1 =
6724       vcgtq_s16(input_val.val[1], range_radius_dup);
6725 
6726   const uint16x8_t mask_leftclamp_0 =
6727       vcltq_s16(input_val.val[0], neg_range_radius_dup);
6728   const uint16x8_t mask_leftclamp_1 =
6729       vcltq_s16(input_val.val[1], neg_range_radius_dup);
6730 
6731   uint8x16x2_t result;
6732   result.val[0] = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
6733                               vshrn_n_u16(mask_leftclamp_1, 8));
6734   result.val[1] = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
6735                               vshrn_n_u16(mask_rightclamp_1, 8));
6736   return result;
6737 }
6738 
6739 inline void ClampWithRangeAndStore(uint8_t* output_dst, uint8x16_t input_val,
6740                                    uint8x16x2_t masks_clamp) {
6741   // Store back to memory
6742   vst1q_u8(output_dst, vandq_u8(vorrq_u8(input_val, masks_clamp.val[1]),
6743                                 masks_clamp.val[0]));
6744 }
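
// Illustrative note (not part of the original implementation): for the
// unsigned variant above, masks_clamp.val[1] is 0xFF exactly where the
// centered input exceeded input_range_radius, so OR-ing it in forces those
// lanes to 255; masks_clamp.val[0] is 0x00 exactly where the centered input
// fell below -input_range_radius, so AND-ing it zeroes those lanes. All other
// lanes keep the computed value.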
6745 
6746 inline void ClampWithRangeAndStore(int8_t* output_dst, int8x16_t input_val,
6747                                    uint8x16x2_t masks_clamp) {
6748   static const int8x16_t max_dup = vdupq_n_s8(127);
6749   static const int8x16_t min_dup = vdupq_n_s8(-128);
6750   // Store back to memory
6751   vst1q_s8(output_dst,
6752            vbslq_s8(masks_clamp.val[1], max_dup,
6753                     vbslq_s8(masks_clamp.val[0], min_dup, input_val)));
6754 }
6755 
6756 #endif  // GEMMLOWP_NEON
6757 
6758 inline void Tanh16bitPrecision(const TanhParams& params,
6759                                const RuntimeShape& input_shape,
6760                                const uint8* input_data,
6761                                const RuntimeShape& output_shape,
6762                                uint8* output_data) {
6763   // Note that this is almost the exact same code as in Logistic().
6764   ruy::profiler::ScopeLabel label("Tanh/Uint8");
6765   const int32 input_zero_point = params.input_zero_point;
6766   const int32 input_range_radius = params.input_range_radius;
6767   const int16 input_multiplier = static_cast<int16>(params.input_multiplier);
6768   const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
6769   const int size = MatchingFlatSize(input_shape, output_shape);
6770 
6771   int c = 0;
6772   int16_t output_zero_point = 128;
6773 
6774 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6775 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6776 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6777 #ifdef GEMMLOWP_NEON
6778   const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
6779   const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
6780   const int16x8_t output_zero_point_s16 = vdupq_n_s16(output_zero_point);
6781 
6782   // Handle 32 values at a time
6783   for (; c <= size - 32; c += 32) {
6784     // Read input uint8 values, cast to int16 and subtract input_zero_point
6785     using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6786     const int16x8x2_t input_val_centered_0_1 =
6787         Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
6788     const int16x8x2_t input_val_centered_2_3 =
6789         Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
6790 
6791     // Prepare the bit masks that we will use at the end to implement the logic
6792     // that was expressed in the scalar code with branching:
6793     //   if (input_val_centered < -input_range_radius) {
6794     //     output_val = 0;
6795     //   } else if (input_val_centered > input_range_radius) {
6796     //     output_val = 255;
6797     //   } else {
6798     //     ...
6799     uint8x16x2_t masks_clamp_0_1 = CalculateUnsignedClampingWithRangeBitMasks(
6800         input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
6801     uint8x16x2_t masks_clamp_2_3 = CalculateUnsignedClampingWithRangeBitMasks(
6802         input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
6803 
6804     int16x8x4_t input_val_rescaled = SaturatingRounding(
6805         input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
6806         input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
6807         input_left_shift, input_multiplier);
6808 
6809     int16x8x4_t output_val_s16 = FixedPoint4Tanh(input_val_rescaled);
6810 
6811     // Add the output zero point
6812     output_val_s16.val[0] =
6813         vaddq_s16(output_val_s16.val[0], output_zero_point_s16);
6814     output_val_s16.val[1] =
6815         vaddq_s16(output_val_s16.val[1], output_zero_point_s16);
6816     output_val_s16.val[2] =
6817         vaddq_s16(output_val_s16.val[2], output_zero_point_s16);
6818     output_val_s16.val[3] =
6819         vaddq_s16(output_val_s16.val[3], output_zero_point_s16);
6820 
6821     // Cast output values to uint8, saturating
6822     uint8x16_t output_val_u8_0_1 = vcombine_u8(
6823         vqmovun_s16(output_val_s16.val[0]), vqmovun_s16(output_val_s16.val[1]));
6824     uint8x16_t output_val_u8_2_3 = vcombine_u8(
6825         vqmovun_s16(output_val_s16.val[2]), vqmovun_s16(output_val_s16.val[3]));
6826 
6827     ClampWithRangeAndStore(output_data + c, output_val_u8_0_1, masks_clamp_0_1);
6828     ClampWithRangeAndStore(output_data + c + 16, output_val_u8_2_3,
6829                            masks_clamp_2_3);
6830   }
6831 #endif  // GEMMLOWP_NEON
6832   // Leftover loop: handle one value at a time with scalar code.
6833   for (; c < size; ++c) {
6834     const uint8 input_val_u8 = input_data[c];
6835     const int16 input_val_centered =
6836         static_cast<int16>(input_val_u8) - input_zero_point;
6837     uint8 output_val;
6838     if (input_val_centered < -input_range_radius) {
6839       output_val = 0;
6840     } else if (input_val_centered > input_range_radius) {
6841       output_val = 255;
6842     } else {
6843       using gemmlowp::SaturatingRoundingDoublingHighMul;
6844       const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
6845           static_cast<int16>(input_val_centered * (1 << input_left_shift)),
6846           static_cast<int16>(input_multiplier));
6847       using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
6848       using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
6849       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
6850       const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
6851       using gemmlowp::RoundingDivideByPOT;
6852       int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8);
6853       output_val_s16 += output_zero_point;
6854       if (output_val_s16 == 256) {
6855         output_val_s16 = 255;
6856       }
6857       TFLITE_DCHECK_GE(output_val_s16, 0);
6858       TFLITE_DCHECK_LE(output_val_s16, 255);
6859       output_val = static_cast<uint8>(output_val_s16);
6860     }
6861     output_data[c] = output_val;
6862   }
6863 }
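
// Illustrative note (not part of the original implementation): in the scalar
// tail above, output_val_f0.raw() is a Q0.15 value in [-1, 1), so
// RoundingDivideByPOT(raw, 8) yields roughly tanh(x) * 128, and adding the
// output zero point of 128 maps tanh outputs of -1, 0 and +1 to 0, 128 and
// 256 (clipped to 255) respectively.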
6864 
6865 inline void Tanh16bitPrecision(const TanhParams& params,
6866                                const RuntimeShape& input_shape,
6867                                const int8* input_data,
6868                                const RuntimeShape& output_shape,
6869                                int8* output_data) {
6870   // Note that this is almost the exact same code as in Logistic().
6871   ruy::profiler::ScopeLabel label("Tanh/Int8");
6872   const int32 input_zero_point = params.input_zero_point;
6873   const int32 input_range_radius = params.input_range_radius;
6874   const int16 input_multiplier = static_cast<int16>(params.input_multiplier);
6875   const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
6876   const int size = MatchingFlatSize(input_shape, output_shape);
6877 
6878   int c = 0;
6879 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6880 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6881 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6882 #ifdef GEMMLOWP_NEON
6883   const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
6884   const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
6885 
6886   // Handle 32 values at a time
6887   for (; c <= size - 32; c += 32) {
6888     // Read input int8 values, cast to int16 and subtract input_zero_point
6889     using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6890     const int16x8x2_t input_val_centered_0_1 =
6891         Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
6892     const int16x8x2_t input_val_centered_2_3 =
6893         Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
6894 
6895     // Prepare the bit masks that we will use at the end to implement the logic
6896     // that was expressed in the scalar code with branching:
6897     //   if (input_val_centered < -input_range_radius) {
6898     //     output_val = -128;
6899     //   } else if (input_val_centered > input_range_radius) {
6900     //     output_val = 127;
6901     //   } else {
6902     //     ...
6903     uint8x16x2_t masks_clamp_0_1 = CalculateSignedClampingWithRangeBitMasks(
6904         input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
6905     uint8x16x2_t masks_clamp_2_3 = CalculateSignedClampingWithRangeBitMasks(
6906         input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
6907 
6908     int16x8x4_t input_val_rescaled = SaturatingRounding(
6909         input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
6910         input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
6911         input_left_shift, input_multiplier);
6912 
6913     int16x8x4_t output_val_s16 = FixedPoint4Tanh(input_val_rescaled);
6914 
6915     // Cast output values to int8, saturating
6916     int8x16_t output_val_s8_0_1 = vcombine_s8(
6917         vqmovn_s16(output_val_s16.val[0]), vqmovn_s16(output_val_s16.val[1]));
6918     int8x16_t output_val_s8_2_3 = vcombine_s8(
6919         vqmovn_s16(output_val_s16.val[2]), vqmovn_s16(output_val_s16.val[3]));
6920 
6921     ClampWithRangeAndStore(output_data + c, output_val_s8_0_1, masks_clamp_0_1);
6922     ClampWithRangeAndStore(output_data + c + 16, output_val_s8_2_3,
6923                            masks_clamp_2_3);
6924   }
6925 #endif  // GEMMLOWP_NEON
6926   // Leftover loop: handle one value at a time with scalar code.
6927   for (; c < size; ++c) {
6928     const int8 input_val_s8 = input_data[c];
6929     const int16 input_val_centered =
6930         static_cast<int16>(input_val_s8) - input_zero_point;
6931     int8 output_val;
6932     if (input_val_centered <= -input_range_radius) {
6933       output_val = -128;
6934     } else if (input_val_centered >= input_range_radius) {
6935       output_val = 127;
6936     } else {
6937       using gemmlowp::SaturatingRoundingDoublingHighMul;
6938       const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
6939           static_cast<int16>(input_val_centered * (1 << input_left_shift)),
6940           static_cast<int16>(input_multiplier));
6941       using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
6942       using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
6943       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
6944       const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
6945       using gemmlowp::RoundingDivideByPOT;
6946       int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8);
6947       if (output_val_s16 == 128) {
6948         output_val_s16 = 127;
6949       }
6950       TFLITE_DCHECK_GE(output_val_s16, -128);
6951       TFLITE_DCHECK_LE(output_val_s16, 127);
6952       output_val = static_cast<int8>(output_val_s16);
6953     }
6954     output_data[c] = output_val;
6955   }
6956 }
6957 
6958 inline void Logistic16bitPrecision(const LogisticParams& params,
6959                                    const RuntimeShape& input_shape,
6960                                    const uint8* input_data,
6961                                    const RuntimeShape& output_shape,
6962                                    uint8* output_data) {
6963   ruy::profiler::ScopeLabel label("Logistic/Uint8");
6964   const int32 input_zero_point = params.input_zero_point;
6965   const int32 input_range_radius = params.input_range_radius;
6966   const int32 input_multiplier = params.input_multiplier;
6967   const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
6968   const int size = MatchingFlatSize(input_shape, output_shape);
6969 
6970   int c = 0;
6971 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6972 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6973 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6974 #ifdef GEMMLOWP_NEON
6975   const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
6976   const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
6977 
6978   // Handle 32 values at a time
6979   for (; c <= size - 32; c += 32) {
6980     // Read input uint8 values, cast to int16 and subtract input_zero_point
6981     using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6982     const int16x8x2_t input_val_centered_0_1 =
6983         Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
6984     const int16x8x2_t input_val_centered_2_3 =
6985         Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
6986 
6987     // Prepare the bit masks that we will use at the end to implement the logic
6988     // that was expressed in the scalar code with branching:
6989     //   if (input_val_centered < -input_range_radius) {
6990     //     output_val = 0;
6991     //   } else if (input_val_centered > input_range_radius) {
6992     //     output_val = 255;
6993     //   } else {
6994     //     ...
6995     uint8x16x2_t masks_clamp_0_1 = CalculateUnsignedClampingWithRangeBitMasks(
6996         input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
6997     uint8x16x2_t masks_clamp_2_3 = CalculateUnsignedClampingWithRangeBitMasks(
6998         input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
6999 
7000     int16x8x4_t input_val_rescaled = SaturatingRounding(
7001         input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
7002         input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
7003         input_left_shift, input_multiplier);
7004 
7005     int16x8x4_t output_val_s16 = FixedPoint4Logistic(input_val_rescaled);
7006 
7007     // Cast output values to uint8, saturating
7008     uint8x16_t output_val_u8_0_1 = vcombine_u8(
7009         vqmovun_s16(output_val_s16.val[0]), vqmovun_s16(output_val_s16.val[1]));
7010     uint8x16_t output_val_u8_2_3 = vcombine_u8(
7011         vqmovun_s16(output_val_s16.val[2]), vqmovun_s16(output_val_s16.val[3]));
7012 
7013     ClampWithRangeAndStore(output_data + c, output_val_u8_0_1, masks_clamp_0_1);
7014     ClampWithRangeAndStore(output_data + c + 16, output_val_u8_2_3,
7015                            masks_clamp_2_3);
7016   }
7017 #endif  // GEMMLOWP_NEON
7018   // Leftover loop: handle one value at a time with scalar code.
7019   for (; c < size; ++c) {
7020     const uint8 input_val_u8 = input_data[c];
7021     const int16 input_val_centered =
7022         static_cast<int16>(input_val_u8) - input_zero_point;
7023     uint8 output_val;
7024     if (input_val_centered < -input_range_radius) {
7025       output_val = 0;
7026     } else if (input_val_centered > input_range_radius) {
7027       output_val = 255;
7028     } else {
7029       using gemmlowp::SaturatingRoundingDoublingHighMul;
7030       const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7031           static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7032           static_cast<int16>(input_multiplier));
7033       using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
7034       using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
7035       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
7036       const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
7037       using gemmlowp::RoundingDivideByPOT;
7038       int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7);
7039       if (output_val_s16 == 256) {
7040         output_val_s16 = 255;
7041       }
7042       TFLITE_DCHECK_GE(output_val_s16, 0);
7043       TFLITE_DCHECK_LE(output_val_s16, 255);
7044       output_val = static_cast<uint8>(output_val_s16);
7045     }
7046     output_data[c] = output_val;
7047   }
7048 }
7049 
7050 inline void Logistic16bitPrecision(const LogisticParams& params,
7051                                    const RuntimeShape& input_shape,
7052                                    const int8* input_data,
7053                                    const RuntimeShape& output_shape,
7054                                    int8* output_data) {
7055   ruy::profiler::ScopeLabel label("Logistic/Int8");
7056   const int32 input_zero_point = params.input_zero_point;
7057   const int32 input_range_radius = params.input_range_radius;
7058   const int32 input_multiplier = params.input_multiplier;
7059   const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
7060   const int size = MatchingFlatSize(input_shape, output_shape);
7061 
7062   int c = 0;
7063   const int16 output_zero_point = 128;
7064 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
7065 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
7066 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
7067 #ifdef GEMMLOWP_NEON
7068   const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
7069   const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
7070   const int16x8_t output_zero_point_dup = vdupq_n_s16(output_zero_point);
7071 
7072   // Handle 32 values at a time
7073   for (; c <= size - 32; c += 32) {
7074     // Read input int8 values, cast to int16 and subtract input_zero_point
7075     using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
7076     const int16x8x2_t input_val_centered_0_1 =
7077         Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
7078     const int16x8x2_t input_val_centered_2_3 =
7079         Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
7080 
7081     // Prepare the bit masks that we will use at the end to implement the logic
7082     // that was expressed in the scalar code with branching:
7083     //   if (input_val_centered < -input_range_radius) {
7084     //     output_val = -128;
7085     //   } else if (input_val_centered > input_range_radius) {
7086     //     output_val = 127;
7087     //   } else {
7088     //     ...
7089     uint8x16x2_t masks_clamp_0_1 = CalculateSignedClampingWithRangeBitMasks(
7090         input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
7091     uint8x16x2_t masks_clamp_2_3 = CalculateSignedClampingWithRangeBitMasks(
7092         input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
7093 
7094     int16x8x4_t input_val_rescaled = SaturatingRounding(
7095         input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
7096         input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
7097         input_left_shift, input_multiplier);
7098 
7099     int16x8x4_t output_val_s16 = FixedPoint4Logistic(input_val_rescaled);
7100 
7101     // Subtract the output zero point.
7102     output_val_s16.val[0] =
7103         vsubq_s16(output_val_s16.val[0], output_zero_point_dup);
7104     output_val_s16.val[1] =
7105         vsubq_s16(output_val_s16.val[1], output_zero_point_dup);
7106     output_val_s16.val[2] =
7107         vsubq_s16(output_val_s16.val[2], output_zero_point_dup);
7108     output_val_s16.val[3] =
7109         vsubq_s16(output_val_s16.val[3], output_zero_point_dup);
7110 
7111     // Cast output values to int8, saturating
7112     int8x16_t output_val_s8_0_1 = vcombine_s8(
7113         vqmovn_s16(output_val_s16.val[0]), vqmovn_s16(output_val_s16.val[1]));
7114     int8x16_t output_val_s8_2_3 = vcombine_s8(
7115         vqmovn_s16(output_val_s16.val[2]), vqmovn_s16(output_val_s16.val[3]));
7116 
7117     ClampWithRangeAndStore(output_data + c, output_val_s8_0_1, masks_clamp_0_1);
7118     ClampWithRangeAndStore(output_data + c + 16, output_val_s8_2_3,
7119                            masks_clamp_2_3);
7120   }
7121 #endif  // GEMMLOWP_NEON
7122   // Leftover loop: handle one value at a time with scalar code.
7123   for (; c < size; ++c) {
7124     const int8 input_val_s8 = input_data[c];
7125     const int16 input_val_centered =
7126         static_cast<int16>(input_val_s8) - input_zero_point;
7127     int8 output_val;
7128     if (input_val_centered < -input_range_radius) {
7129       output_val = -128;
7130     } else if (input_val_centered > input_range_radius) {
7131       output_val = 127;
7132     } else {
7133       using gemmlowp::SaturatingRoundingDoublingHighMul;
7134       const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7135           static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7136           static_cast<int16>(input_multiplier));
7137       using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
7138       using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
7139       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
7140       const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
7141       using gemmlowp::RoundingDivideByPOT;
7142       int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7);
7143       output_val_s16 -= output_zero_point;
7144       if (output_val_s16 == 128) {
7145         output_val_s16 = 127;
7146       }
7147       TFLITE_DCHECK_GE(output_val_s16, -128);
7148       TFLITE_DCHECK_LE(output_val_s16, 127);
7149       output_val = static_cast<int8>(output_val_s16);
7150     }
7151     output_data[c] = output_val;
7152   }
7153 }
7154 
7155 // Transpose2D only deals with typical 2D matrix transpose ops.
7156 // It performs the transpose on 4x4 blocks of the input, proceeding from left
7157 // to right across each band of four rows, and then from top to bottom.
7158 template <typename T>
7159 inline void Transpose2D(const RuntimeShape& input_shape, const T* input_data,
7160                         const RuntimeShape& output_shape, T* output_data) {
7161   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
7162   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
7163 
7164   const int d0 = input_shape.DimsData()[0];
7165   const int d1 = input_shape.DimsData()[1];
7166   const int kLines = 4;
7167   const int kSkipSize = (kLines - 1) * d1;
7168 
7169   const T* input = input_data;
7170 
7171   int i = 0;
7172   for (; i <= d0 - kLines; i += kLines) {
7173     T* output = output_data + i;
7174 
7175     const T* input_ptr = input;
7176     optimized_ops_preload_l1_keep(input_ptr);
7177     input_ptr += d1;
7178     optimized_ops_preload_l1_keep(input_ptr);
7179     input_ptr += d1;
7180     optimized_ops_preload_l1_keep(input_ptr);
7181     input_ptr += d1;
7182     optimized_ops_preload_l1_keep(input_ptr);
7183 
7184     int j = 0;
7185     for (; j <= d1 - kLines; j += kLines) {
7186       input_ptr = input;
7187       const T a00 = input_ptr[0];
7188       const T a01 = input_ptr[1];
7189       const T a02 = input_ptr[2];
7190       const T a03 = input_ptr[3];
7191       input_ptr += d1;
7192       const T a10 = input_ptr[0];
7193       const T a11 = input_ptr[1];
7194       const T a12 = input_ptr[2];
7195       const T a13 = input_ptr[3];
7196       input_ptr += d1;
7197       const T a20 = input_ptr[0];
7198       const T a21 = input_ptr[1];
7199       const T a22 = input_ptr[2];
7200       const T a23 = input_ptr[3];
7201       input_ptr += d1;
7202       const T a30 = input_ptr[0];
7203       const T a31 = input_ptr[1];
7204       const T a32 = input_ptr[2];
7205       const T a33 = input_ptr[3];
7206 
7207       output[0] = a00;
7208       output[1] = a10;
7209       output[2] = a20;
7210       output[3] = a30;
7211       output += d0;
7212 
7213       output[0] = a01;
7214       output[1] = a11;
7215       output[2] = a21;
7216       output[3] = a31;
7217       output += d0;
7218 
7219       output[0] = a02;
7220       output[1] = a12;
7221       output[2] = a22;
7222       output[3] = a32;
7223       output += d0;
7224 
7225       output[0] = a03;
7226       output[1] = a13;
7227       output[2] = a23;
7228       output[3] = a33;
7229       output += d0;
7230 
7231       input += kLines;
7232     }
7233     if (j == d1) {
7234       input += kSkipSize;
7235     } else {
7236       for (int p = 0; p < kLines; ++p) {
7237         for (int q = 0; q < d1 - j; ++q) {
7238           *(output + q * d0 + p) = *(input + p * d1 + q);
7239         }
7240       }
7241       input += (d1 - j) + kSkipSize;
7242     }
7243   }
7244   for (; i < d0; ++i) {
7245     T* output = output_data + i;
7246     for (int j = 0; j < d1; ++j) {
7247       *output = *input;
7248       output += d0;
7249       ++input;
7250     }
7251   }
7252 }
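
// Illustrative example (not part of the original implementation): transposing
// a row-major 2x3 matrix with the routine above,
//   input_shape = {2, 3}, input_data = {1, 2, 3, 4, 5, 6},
// produces
//   output_shape = {3, 2}, output_data = {1, 4, 2, 5, 3, 6},
// i.e. element (r, c) of the input lands at (c, r) of the output. Since both
// dimensions are smaller than 4, this case is handled entirely by the scalar
// leftover loops.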
7253 
7254 template <>
7255 inline void Transpose2D(const RuntimeShape& input_shape,
7256                         const int32_t* input_data,
7257                         const RuntimeShape& output_shape,
7258                         int32_t* output_data) {
7259   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
7260   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
7261 
7262   const int d0 = input_shape.DimsData()[0];
7263   const int d1 = input_shape.DimsData()[1];
7264 #ifdef USE_NEON
7265   const int kLines = 4;
7266   const int kSkipSize = (kLines - 1) * d1;
7267 #endif
7268 
7269   const int32_t* input = input_data;
7270 
7271   int i = 0;
7272 #ifdef USE_NEON
7273   for (; i <= d0 - kLines; i += kLines) {
7274     int32_t* output = output_data + i;
7275 
7276     const int32_t* input_ptr = input;
7277     optimized_ops_preload_l1_keep(input_ptr);
7278     input_ptr += d1;
7279     optimized_ops_preload_l1_keep(input_ptr);
7280     input_ptr += d1;
7281     optimized_ops_preload_l1_keep(input_ptr);
7282     input_ptr += d1;
7283     optimized_ops_preload_l1_keep(input_ptr);
7284 
7285     int j = 0;
7286     for (; j <= d1 - kLines; j += kLines) {
7287       input_ptr = input;
7288       int32x4_t a0 = vld1q_s32(input);
7289       input_ptr += d1;
7290       int32x4_t a1 = vld1q_s32(input_ptr);
7291       input_ptr += d1;
7292       int32x4_t a2 = vld1q_s32(input_ptr);
7293       input_ptr += d1;
7294       int32x4_t a3 = vld1q_s32(input_ptr);
7295 
7296       int32x4x2_t tmp1 = vuzpq_s32(a0, a2);
7297       int32x4x2_t tmp2 = vuzpq_s32(a1, a3);
7298       int32x4x2_t tmp3 = vtrnq_s32(tmp1.val[0], tmp2.val[0]);
7299       int32x4x2_t tmp4 = vtrnq_s32(tmp1.val[1], tmp2.val[1]);
7300 
7301       vst1q_s32(output, tmp3.val[0]);
7302       output += d0;
7303       vst1q_s32(output, tmp4.val[0]);
7304       output += d0;
7305       vst1q_s32(output, tmp3.val[1]);
7306       output += d0;
7307       vst1q_s32(output, tmp4.val[1]);
7308       output += d0;
7309       input += kLines;
7310     }
7311     if (j == d1) {
7312       input += kSkipSize;
7313     } else {
7314       for (int p = 0; p < kLines; ++p) {
7315         for (int q = 0; q < d1 - j; ++q) {
7316           *(output + q * d0 + p) = *(input + p * d1 + q);
7317         }
7318       }
7319       input += (d1 - j) + kSkipSize;
7320     }
7321   }
7322 #endif
7323   for (; i < d0; ++i) {
7324     int32_t* output = output_data + i;
7325     for (int j = 0; j < d1; ++j) {
7326       *output = *input;
7327       output += d0;
7328       ++input;
7329     }
7330   }
7331 }
7332 
7333 // TODO(b/173718660): see if we can reduce the number
7334 // of lines of code in branching without affecting latency.
7335 template <typename T>
7336 inline void Transpose3D(const TransposeParams& params,
7337                         const RuntimeShape& input_shape, const T* input_data,
7338                         const RuntimeShape& output_shape, T* output_data) {
7339   int s1, s2, s3;
7340   s1 = input_shape.Dims(0);
7341   s2 = input_shape.Dims(1);
7342   s3 = input_shape.Dims(2);
7343 
7344   int p1, p2, p3;
7345   if (params.perm[0] == 2) {
7346     p1 = 1;
7347   } else if (params.perm[1] == 2) {
7348     p2 = 1;
7349   } else {
7350     p3 = 1;
7351   }
7352 
7353   if (params.perm[0] == 1) {
7354     p1 = s3;
7355   } else if (params.perm[1] == 1) {
7356     p2 = s3;
7357   } else {
7358     p3 = s3;
7359   }
7360 
7361   if (params.perm[0] == 0) {
7362     p1 = s2 * s3;
7363   } else if (params.perm[1] == 0) {
7364     p2 = s2 * s3;
7365   } else {
7366     p3 = s2 * s3;
7367   }
7368 
7369   int o_s[3];
7370   o_s[0] = input_shape.Dims(params.perm[0]);
7371   o_s[1] = input_shape.Dims(params.perm[1]);
7372   o_s[2] = input_shape.Dims(params.perm[2]);
7373 
7374   for (int i1 = 0; i1 < o_s[0]; ++i1) {
7375     for (int i2 = 0; i2 < o_s[1]; ++i2) {
7376       for (int i3 = 0; i3 < o_s[2]; ++i3) {
7377         const int i = i1 * p1 + i2 * p2 + i3 * p3;
7378         const int o = i1 * o_s[1] * o_s[2] + i2 * o_s[2] + i3;
7379         output_data[o] = input_data[i];
7380       }
7381     }
7382   }
7383 }
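
// Worked example (not part of the original implementation): the branches above
// assign to p1/p2/p3 the input stride of whichever input axis lands in output
// position 0/1/2. For input dims (s1, s2, s3) = (2, 3, 4) and perm = {2, 0, 1},
// this gives p1 = 1, p2 = s2 * s3 = 12, p3 = s3 = 4 and an output shape of
// (4, 2, 3), so output element (i1, i2, i3) is read from input flat index
// i1 + 12 * i2 + 4 * i3.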
7384 
7385 template <typename T, int N>
7386 void TransposeImpl(const TransposeParams& params,
7387                    const RuntimeShape& input_shape, const T* input_data,
7388                    const RuntimeShape& output_shape, T* output_data) {
7389   const int dims_cnt = input_shape.DimensionsCount();
7390 
7391   int dim0, dim1;
7392   if (transpose_utils::IsTranspose2DApplicable(params, input_shape, &dim0,
7393                                                &dim1)) {
7394     Transpose2D(RuntimeShape({dim0, dim1}), input_data,
7395                 RuntimeShape({dim1, dim0}), output_data);
7396     return;
7397   }
7398 
7399   // TODO(b/141217325): notably Eigen is better suited for
7400   // larger inputs whereas Transpose3D is generally
7401   // better for smaller ones.
7402   //
7403   // E.g. on Nexus 5, Eigen is better for size 96^3 and up
7404   // and Transpose3D is better for 72^3 and down.
7405   //
7406   // 96^3 is not mobile-friendly for certain use cases
7407   // (e.g. a model used in beam search for seq2seq) but is in others.
7408   // Consider tradeoffs.
7409   if (dims_cnt == 3) {
7410     Transpose3D(params, input_shape, input_data, output_shape, output_data);
7411     return;
7412   }
7413 
7414   // Reroute to the reference version if an optimized method for the given data
7415   // is not available.
7416   reference_ops::Transpose<T, N>(params, input_shape, input_data, output_shape,
7417                                  output_data);
7418 }
7419 
7420 template <typename T, int N = 5>
7421 void Transpose(const TransposeParams& unshrinked_params,
7422                const RuntimeShape& unshrinked_input_shape, const T* input_data,
7423                const RuntimeShape& unshrinked_output_shape, T* output_data) {
7424   ruy::profiler::ScopeLabel label("Transpose");
7425 
7426   const int output_size = unshrinked_output_shape.DimensionsCount();
7427   TFLITE_DCHECK_LE(unshrinked_input_shape.DimensionsCount(), N);
7428   TFLITE_DCHECK_LE(output_size, N);
7429   TFLITE_DCHECK_EQ(output_size, unshrinked_params.perm_count);
7430 
7431   RuntimeShape shrinked_input_shape = RuntimeShape(unshrinked_input_shape);
7432   RuntimeShape shrinked_output_shape = RuntimeShape(unshrinked_output_shape);
7433   TransposeParams shrinked_params = unshrinked_params;
7434 
7435   // Remove any dimensions of size one. A lower-rank transpose op usually
7436   // performs better since the memory access pattern is improved.
7437   transpose_utils::RemoveOneSizeDimensions(
7438       &shrinked_input_shape, &shrinked_output_shape, &shrinked_params);
7439 
7440   // Handle identity cases.
7441   // TODO(b/140779653): Add an optimization pass in the conversion process to
7442   // remove transpose op nodes that do nothing, like the one handled below.
7443   bool identical = true;
7444   for (int i = 0; i < shrinked_params.perm_count; ++i) {
7445     if (shrinked_params.perm[i] != i) {
7446       identical = false;
7447       break;
7448     }
7449   }
7450   if (identical) {
7451     memcpy(output_data, input_data,
7452            unshrinked_input_shape.FlatSize() * sizeof(T));
7453     return;
7454   }
7455 
7456   // Reduce dimensions by flattening.
7457   if (shrinked_params.perm[0] == 0 && output_size >= 3) {
7458     RuntimeShape non_flatten_input_shape;
7459     RuntimeShape non_flatten_output_shape;
7460     TransposeParams non_flatten_params;
7461     const int total_size = shrinked_input_shape.FlatSize();
7462     const int non_flatten_size = transpose_utils::Flatten(
7463         shrinked_input_shape, shrinked_output_shape, shrinked_params,
7464         &non_flatten_input_shape, &non_flatten_output_shape,
7465         &non_flatten_params);
7466     TFLITE_DCHECK_NE(non_flatten_params.perm[0], 0);
7467 
7468     for (int i = 0; i < total_size; i += non_flatten_size) {
7469       TransposeImpl<T, N>(non_flatten_params, non_flatten_input_shape,
7470                           input_data + i, non_flatten_output_shape,
7471                           output_data + i);
7472     }
7473     return;
7474   }
7475 
7476   // Call non-flattened case.
7477   TransposeImpl<T, N>(shrinked_params, shrinked_input_shape, input_data,
7478                       shrinked_output_shape, output_data);
7479 }
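
// Illustrative usage sketch (assumed buffers and field values, not taken from
// the original source): a plain 2D transpose through the generic entry point
// above.
//
//   TransposeParams params;
//   params.perm_count = 2;
//   params.perm[0] = 1;
//   params.perm[1] = 0;
//   Transpose<float>(params, RuntimeShape({2, 3}), input_data,
//                    RuntimeShape({3, 2}), output_data);
//
// After removing size-one dimensions, this is detected as a 2D transpose and
// dispatched to Transpose2D.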
7480 
7481 // Assume input1 & input2 have the same scale & zero point.
7482 inline void MaximumElementwise(int size, const ArithmeticParams& params,
7483                                const int8* input1_data, const int8* input2_data,
7484                                int8* output_data) {
7485   ruy::profiler::ScopeLabel label("MaximumElementwiseInt8/8bit");
7486   int i = 0;
7487 #ifdef USE_NEON
7488   for (; i <= size - 16; i += 16) {
7489     const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
7490     const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7491     const int8x16_t max_data =
7492         vmaxq_s8(input1_val_original, input2_val_original);
7493     vst1q_s8(output_data + i, max_data);
7494   }
7495 #endif  // USE_NEON
7496   for (; i < size; ++i) {
7497     const int8 input1_val = input1_data[i];
7498     const int8 input2_val = input2_data[i];
7499     output_data[i] = std::max(input1_val, input2_val);
7500   }
7501 }
7502 
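// Illustrative usage sketch (not part of the original file): both inputs are
// assumed to share the same quantization scale and zero point, so raw int8
// values can be compared directly. The params argument is not used by this
// helper; a default-constructed ArithmeticParams is passed only to satisfy
// the signature.
//
//   const int8 a[4] = {-1, 5, 3, -7};
//   const int8 b[4] = {2, 4, 3, -8};
//   int8 out[4];
//   MaximumElementwise(4, ArithmeticParams(), a, b, out);
//   // out is now {2, 5, 3, -7}.
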
inline void MaximumScalarBroadcast(int size, const ArithmeticParams& params,
                                   int8 input1_data, const int8* input2_data,
                                   int8* output_data) {
  ruy::profiler::ScopeLabel label("MaximumScalarBroadcastInt8/8bit");
  int i = 0;

#ifdef USE_NEON
  const int8x16_t input1_val_original = vdupq_n_s8(input1_data);
  for (; i <= size - 16; i += 16) {
    const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
    const int8x16_t max_data =
        vmaxq_s8(input1_val_original, input2_val_original);
    vst1q_s8(output_data + i, max_data);
  }
#endif  // USE_NEON
  for (; i < size; ++i) {
    const int8 input2_val = input2_data[i];
    output_data[i] = std::max(input1_data, input2_val);
  }
}

// Assume input1 & input2 have the same scale & zero point.
inline void MinimumElementwise(int size, const ArithmeticParams& params,
                               const int8* input1_data, const int8* input2_data,
                               int8* output_data) {
  ruy::profiler::ScopeLabel label("MinimumElementwiseInt8/8bit");
  int i = 0;
#ifdef USE_NEON
  for (; i <= size - 16; i += 16) {
    const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
    const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
    const int8x16_t min_data =
        vminq_s8(input1_val_original, input2_val_original);
    vst1q_s8(output_data + i, min_data);
  }
#endif  // USE_NEON
  for (; i < size; ++i) {
    const int8 input1_val = input1_data[i];
    const int8 input2_val = input2_data[i];
    output_data[i] = std::min(input1_val, input2_val);
  }
}

inline void MinimumScalarBroadcast(int size, const ArithmeticParams& params,
                                   int8 input1_data, const int8* input2_data,
                                   int8* output_data) {
  ruy::profiler::ScopeLabel label("MinimumScalarBroadcastInt8/8bit");
  int i = 0;

#ifdef USE_NEON
  const int8x16_t input1_val_original = vdupq_n_s8(input1_data);
  for (; i <= size - 16; i += 16) {
    const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
    const int8x16_t min_data =
        vminq_s8(input1_val_original, input2_val_original);
    vst1q_s8(output_data + i, min_data);
  }
#endif  // USE_NEON
  for (; i < size; ++i) {
    const int8 input2_val = input2_data[i];
    output_data[i] = std::min(input1_data, input2_val);
  }
}

template <typename Op>
inline void BroadcastMaximumDispatch(const ArithmeticParams& params,
                                     const RuntimeShape& input1_shape,
                                     const int8* input1_data,
                                     const RuntimeShape& input2_shape,
                                     const int8* input2_data,
                                     const RuntimeShape& output_shape,
                                     int8* output_data, Op op) {
  if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
    return reference_ops::MaximumMinimumBroadcastSlow(
        input1_shape, input1_data, input2_shape, input2_data, output_shape,
        output_data, op);
  }

  BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
                          input2_data, output_shape, output_data,
                          MaximumElementwise, MaximumScalarBroadcast);
}

template <typename Op>
inline void BroadcastMinimumDispatch(const ArithmeticParams& params,
                                     const RuntimeShape& input1_shape,
                                     const int8* input1_data,
                                     const RuntimeShape& input2_shape,
                                     const int8* input2_data,
                                     const RuntimeShape& output_shape,
                                     int8* output_data, Op op) {
  if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
    return reference_ops::MaximumMinimumBroadcastSlow(
        input1_shape, input1_data, input2_shape, input2_data, output_shape,
        output_data, op);
  }

  BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
                          input2_data, output_shape, output_data,
                          MinimumElementwise, MinimumScalarBroadcast);
}

template <typename T>
void CumsumImpl(const T* input_data, const RuntimeShape& shape, int axis,
                bool exclusive, bool reverse, T* output_data) {
  Eigen::array<Eigen::DenseIndex, 3> dims = {1, 1, 1};

  for (int i = 0; i < axis; ++i) {
    dims[0] *= shape.Dims(i);
  }
  dims[1] = shape.Dims(axis);
  for (int i = axis + 1; i < shape.DimensionsCount(); ++i) {
    dims[2] *= shape.Dims(i);
  }

  typedef Eigen::TensorMap<
      Eigen::Tensor<const T, 3, Eigen::RowMajor, Eigen::DenseIndex>,
      Eigen::Aligned>
      ConstTensor;
  typedef Eigen::TensorMap<
      Eigen::Tensor<T, 3, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned>
      Tensor;
  ConstTensor input(input_data, dims);
  Tensor output(output_data, dims);

  if (reverse) {
    Eigen::array<bool, 3> reverse_idx = {false, true, false};
    output =
        input.reverse(reverse_idx).cumsum(1, exclusive).reverse(reverse_idx);
  } else {
    output = input.cumsum(1, exclusive);
  }
}

template <typename T>
void CumSum(const T* input_data, const RuntimeShape& shape, int axis,
            bool exclusive, bool reverse, T* output_data) {
  const int dim = shape.DimensionsCount();
  TFLITE_DCHECK_GE(dim, 1);
  CumsumImpl<T>(input_data, shape, axis, exclusive, reverse, output_data);
}

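// Illustrative usage sketch (not part of the original file): an inclusive,
// forward cumulative sum along axis 1 of a 2x3 int32 tensor.
//
//   const RuntimeShape shape({2, 3});
//   const int32_t input[6] = {1, 2, 3, 4, 5, 6};
//   int32_t output[6];
//   CumSum(input, shape, /*axis=*/1, /*exclusive=*/false, /*reverse=*/false,
//          output);
//   // output is now {1, 3, 6, 4, 9, 15}.
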
inline void PReluScalarBroadcast(int size, const ArithmeticParams& params,
                                 float alpha, const float* input_data,
                                 float* output_data) {
  ruy::profiler::ScopeLabel label("PreluScalarBroadcast/float");
  int i = 0;

#ifdef USE_NEON
  const float32x4_t zero_dup = vdupq_n_f32(0.0f);
  const float32x4_t alpha_dup = vdupq_n_f32(alpha);
  for (; i <= size - 16; i += 16) {
    const float32x4_t input1 = vld1q_f32(input_data + i);
    const float32x4_t input2 = vld1q_f32(input_data + i + 4);
    const float32x4_t input3 = vld1q_f32(input_data + i + 8);
    const float32x4_t input4 = vld1q_f32(input_data + i + 12);

    const float32x4_t temp1 = vmulq_f32(input1, alpha_dup);
    const float32x4_t temp2 = vmulq_f32(input2, alpha_dup);
    const float32x4_t temp3 = vmulq_f32(input3, alpha_dup);
    const float32x4_t temp4 = vmulq_f32(input4, alpha_dup);

    const uint32x4_t mask1 = vcgeq_f32(input1, zero_dup);
    const uint32x4_t mask2 = vcgeq_f32(input2, zero_dup);
    const uint32x4_t mask3 = vcgeq_f32(input3, zero_dup);
    const uint32x4_t mask4 = vcgeq_f32(input4, zero_dup);

    const float32x4_t result1 = vbslq_f32(mask1, input1, temp1);
    vst1q_f32(output_data + i, result1);
    const float32x4_t result2 = vbslq_f32(mask2, input2, temp2);
    vst1q_f32(output_data + i + 4, result2);
    const float32x4_t result3 = vbslq_f32(mask3, input3, temp3);
    vst1q_f32(output_data + i + 8, result3);
    const float32x4_t result4 = vbslq_f32(mask4, input4, temp4);
    vst1q_f32(output_data + i + 12, result4);
  }

  for (; i <= size - 4; i += 4) {
    const float32x4_t input = vld1q_f32(input_data + i);
    const float32x4_t temp = vmulq_f32(input, alpha_dup);
    const uint32x4_t mask = vcgeq_f32(input, zero_dup);
    const float32x4_t result = vbslq_f32(mask, input, temp);
    vst1q_f32(output_data + i, result);
  }
#endif  // USE_NEON
  for (; i < size; ++i) {
    const float input = input_data[i];
    output_data[i] = input >= 0.f ? input : input * alpha;
  }
}

inline void PReluElementWise(int flat_size, const ArithmeticParams& params,
                             const float* alpha_data, const float* input_data,
                             float* output_data) {
  ruy::profiler::ScopeLabel label("PreluElementWise/float");

  int i = 0;
#ifdef USE_NEON
  const float32x4_t zero_dup = vdupq_n_f32(0.0f);
  for (; i <= flat_size - 16; i += 16) {
    const float32x4_t input1 = vld1q_f32(input_data + i);
    const float32x4_t alpha1 = vld1q_f32(alpha_data + i);
    const float32x4_t input2 = vld1q_f32(input_data + i + 4);
    const float32x4_t alpha2 = vld1q_f32(alpha_data + i + 4);
    const float32x4_t input3 = vld1q_f32(input_data + i + 8);
    const float32x4_t alpha3 = vld1q_f32(alpha_data + i + 8);
    const float32x4_t input4 = vld1q_f32(input_data + i + 12);
    const float32x4_t alpha4 = vld1q_f32(alpha_data + i + 12);

    const float32x4_t temp1 = vmulq_f32(input1, alpha1);
    const float32x4_t temp2 = vmulq_f32(input2, alpha2);
    const float32x4_t temp3 = vmulq_f32(input3, alpha3);
    const float32x4_t temp4 = vmulq_f32(input4, alpha4);

    const uint32x4_t mask1 = vcgeq_f32(input1, zero_dup);
    const uint32x4_t mask2 = vcgeq_f32(input2, zero_dup);
    const uint32x4_t mask3 = vcgeq_f32(input3, zero_dup);
    const uint32x4_t mask4 = vcgeq_f32(input4, zero_dup);

    const float32x4_t result1 = vbslq_f32(mask1, input1, temp1);
    vst1q_f32(output_data + i, result1);
    const float32x4_t result2 = vbslq_f32(mask2, input2, temp2);
    vst1q_f32(output_data + i + 4, result2);
    const float32x4_t result3 = vbslq_f32(mask3, input3, temp3);
    vst1q_f32(output_data + i + 8, result3);
    const float32x4_t result4 = vbslq_f32(mask4, input4, temp4);
    vst1q_f32(output_data + i + 12, result4);
  }

  for (; i <= flat_size - 4; i += 4) {
    const float32x4_t input = vld1q_f32(input_data + i);
    const float32x4_t alpha = vld1q_f32(alpha_data + i);

    const float32x4_t temp = vmulq_f32(input, alpha);
    const uint32x4_t mask = vcgeq_f32(input, zero_dup);
    const float32x4_t result = vbslq_f32(mask, input, temp);
    vst1q_f32(output_data + i, result);
  }
#endif  // USE_NEON
  for (; i < flat_size; ++i) {
    const float input = input_data[i];
    const float alpha = alpha_data[i];
    output_data[i] = input >= 0.f ? input : input * alpha;
  }
}

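// Illustrative usage sketch (not part of the original file): PReLU keeps
// non-negative values and scales negative values by the per-element alpha.
// The params argument is not used by this helper; a default-constructed
// ArithmeticParams is passed only to satisfy the signature.
//
//   const float input[4] = {-2.f, -1.f, 0.f, 3.f};
//   const float alpha[4] = {0.5f, 0.1f, 0.2f, 0.3f};
//   float output[4];
//   PReluElementWise(4, ArithmeticParams(), alpha, input, output);
//   // output is now {-1.f, -0.1f, 0.f, 3.f}.
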
inline void BroadcastPReluDispatch(
    const ArithmeticParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& alpha_shape,
    const float* alpha_data, const RuntimeShape& output_shape,
    float* output_data, float (*func)(float, float)) {
  if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
    return reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
        input_shape, input_data, alpha_shape, alpha_data, output_shape,
        output_data, func);
  }

  BinaryBroadcastFiveFold(params, input_shape, input_data, alpha_shape,
                          alpha_data, output_shape, output_data,
                          PReluElementWise, PReluScalarBroadcast);
}

// Returns the index with minimum value within `input_data`.
// If there is a tie, returns the smaller index.
template <typename T>
inline int ArgMinVector(const T* input_data, int size) {
  T min_value = input_data[0];
  int min_index = 0;
  for (int i = 1; i < size; ++i) {
    const T curr_value = input_data[i];
    if (curr_value < min_value) {
      min_value = curr_value;
      min_index = i;
    }
  }
  return min_index;
}

// Returns the index with maximum value within `input_data`.
// If there is a tie, returns the smaller index.
template <typename T>
inline int ArgMaxVector(const T* input_data, int size) {
  T max_value = input_data[0];
  int max_index = 0;
  for (int i = 1; i < size; ++i) {
    const T curr_value = input_data[i];
    if (curr_value > max_value) {
      max_value = curr_value;
      max_index = i;
    }
  }
  return max_index;
}

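// Illustrative usage sketch (not part of the original file): ties resolve to
// the smaller index in both helpers.
//
//   const float values[4] = {1.f, 5.f, 5.f, -3.f};
//   ArgMaxVector<float>(values, 4);  // returns 1 (first of the two 5.f).
//   ArgMinVector<float>(values, 4);  // returns 3.
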
template <>
inline int ArgMinVector(const float* input_data, int size) {
  int32_t min_index = 0;
  float min_value = input_data[0];
  int32_t i = 1;
#ifdef USE_NEON
  if (size >= 4) {
    float32x4_t min_value_f32x4 = vld1q_f32(input_data);
    const int32_t index_init[4] = {0, 1, 2, 3};
    int32x4_t min_index_s32x4 = vld1q_s32(index_init);
    int32x4_t index_s32x4 = min_index_s32x4;
    int32x4_t inc = vdupq_n_s32(4);
    for (i = 4; i <= size - 4; i += 4) {
      // Increase indices by 4.
      index_s32x4 = vaddq_s32(index_s32x4, inc);
      float32x4_t v = vld1q_f32(&input_data[i]);
      uint32x4_t mask = vcltq_f32(v, min_value_f32x4);
      min_value_f32x4 = vminq_f32(min_value_f32x4, v);
      min_index_s32x4 = vbslq_s32(mask, index_s32x4, min_index_s32x4);
    }
    // Find min element within float32x4_t.
#ifdef __aarch64__
    min_value = vminvq_f32(min_value_f32x4);
#else
    float32x2_t min_value_f32x2 = vpmin_f32(vget_low_f32(min_value_f32x4),
                                            vget_high_f32(min_value_f32x4));
    min_value_f32x2 = vpmin_f32(min_value_f32x2, min_value_f32x2);
    min_value = vget_lane_f32(min_value_f32x2, 0);
#endif  // __aarch64__
    // Mask indices of non-min values with max int32_t.
    float32x4_t fill_min_value_f32x4 = vdupq_n_f32(min_value);
    uint32x4_t mask = vceqq_f32(min_value_f32x4, fill_min_value_f32x4);
    int32x4_t all_set = vdupq_n_s32(std::numeric_limits<int>::max());
    min_index_s32x4 = vbslq_s32(mask, min_index_s32x4, all_set);
    // Find min index of min values.
#ifdef __aarch64__
    min_index = vminvq_s32(min_index_s32x4);
#else
    int32x2_t min_index_s32x2 = vpmin_s32(vget_low_s32(min_index_s32x4),
                                          vget_high_s32(min_index_s32x4));
    min_index_s32x2 = vpmin_s32(min_index_s32x2, min_index_s32x2);
    min_index = vget_lane_s32(min_index_s32x2, 0);
#endif  // __aarch64__
  }
#endif  // USE_NEON
  // Leftover loop.
  for (; i < size; ++i) {
    const float curr_value = input_data[i];
    if (curr_value < min_value) {
      min_value = curr_value;
      min_index = i;
    }
  }
  return min_index;
}

template <>
inline int ArgMaxVector(const float* input_data, int size) {
  int32_t max_index = 0;
  float max_value = input_data[0];
  int32_t i = 1;
#ifdef USE_NEON
  if (size >= 4) {
    float32x4_t max_value_f32x4 = vld1q_f32(input_data);
    const int32_t index_init[4] = {0, 1, 2, 3};
    int32x4_t max_index_s32x4 = vld1q_s32(index_init);
    int32x4_t index_s32x4 = max_index_s32x4;
    int32x4_t inc = vdupq_n_s32(4);
    for (i = 4; i <= size - 4; i += 4) {
      // Increase indices by 4.
      index_s32x4 = vaddq_s32(index_s32x4, inc);
      float32x4_t v = vld1q_f32(&input_data[i]);
      uint32x4_t mask = vcgtq_f32(v, max_value_f32x4);
      max_value_f32x4 = vmaxq_f32(max_value_f32x4, v);
      max_index_s32x4 = vbslq_s32(mask, index_s32x4, max_index_s32x4);
    }
    // Find max element within float32x4_t.
#ifdef __aarch64__
    max_value = vmaxvq_f32(max_value_f32x4);
#else
    float32x2_t max_value_f32x2 = vpmax_f32(vget_low_f32(max_value_f32x4),
                                            vget_high_f32(max_value_f32x4));
    max_value_f32x2 = vpmax_f32(max_value_f32x2, max_value_f32x2);
    max_value = vget_lane_f32(max_value_f32x2, 0);
#endif  // __aarch64__
    // Mask indices of non-max values with max int32_t.
    float32x4_t fill_max_value_f32x4 = vdupq_n_f32(max_value);
    uint32x4_t mask = vceqq_f32(max_value_f32x4, fill_max_value_f32x4);
    int32x4_t all_set = vdupq_n_s32(std::numeric_limits<int>::max());
    max_index_s32x4 = vbslq_s32(mask, max_index_s32x4, all_set);
    // Find min index of max values.
#ifdef __aarch64__
    max_index = vminvq_s32(max_index_s32x4);
#else
    int32x2_t max_index_s32x2 = vpmin_s32(vget_low_s32(max_index_s32x4),
                                          vget_high_s32(max_index_s32x4));
    max_index_s32x2 = vpmin_s32(max_index_s32x2, max_index_s32x2);
    max_index = vget_lane_s32(max_index_s32x2, 0);
#endif  // __aarch64__
  }
#endif  // USE_NEON
  // Leftover loop.
  for (; i < size; ++i) {
    const float curr_value = input_data[i];
    if (curr_value > max_value) {
      max_value = curr_value;
      max_index = i;
    }
  }
  return max_index;
}

template <>
inline int ArgMaxVector(const int8_t* input_data, int size) {
  int32_t max_index = 0;
  int8_t max_value = input_data[0];
  int32_t i = 0;
#ifdef USE_NEON
  constexpr int VECTOR_SIZE = 16;
  if (size >= VECTOR_SIZE) {
    int8x16_t max_value_s8x16;
    for (; i <= size - VECTOR_SIZE; i += VECTOR_SIZE) {
      max_value_s8x16 = vld1q_s8(input_data + i);
      int8_t max_from_vec;
#ifdef __aarch64__
      max_from_vec = vmaxvq_s8(max_value_s8x16);
#else   // 32 bit
      int8x8_t max_val_s8x8 =
          vpmax_s8(vget_low_s8(max_value_s8x16), vget_high_s8(max_value_s8x16));
      max_val_s8x8 = vpmax_s8(max_val_s8x8, max_val_s8x8);
      max_val_s8x8 = vpmax_s8(max_val_s8x8, max_val_s8x8);
      max_val_s8x8 = vpmax_s8(max_val_s8x8, max_val_s8x8);
      max_from_vec = vget_lane_s8(max_val_s8x8, 0);
#endif  // __aarch64__
      if (max_from_vec > max_value) {
        max_value = max_from_vec;
        max_index = i;
      }
    }
  }
  for (int start_idx = max_index; start_idx < max_index + VECTOR_SIZE;
       start_idx++) {
    if (input_data[start_idx] == max_value) {
      max_index = start_idx;
      break;
    }
  }

#endif  // USE_NEON
  // Leftover loop.
  for (; i < size; ++i) {
    const int8_t curr_value = input_data[i];
    if (curr_value > max_value) {
      max_value = curr_value;
      max_index = i;
    }
  }

  return max_index;
}

template <>
inline int ArgMaxVector(const uint8_t* input_data, int size) {
  int32_t max_index = 0;
  uint8_t max_value = input_data[0];
  int32_t i = 0;
#ifdef USE_NEON
  constexpr int VECTOR_SIZE = 16;
  if (size >= VECTOR_SIZE) {
    uint8x16_t max_value_u8x16;
    for (; i <= size - VECTOR_SIZE; i += VECTOR_SIZE) {
      max_value_u8x16 = vld1q_u8(input_data + i);
      uint8_t max_from_vec;
#ifdef __aarch64__
      max_from_vec = vmaxvq_u8(max_value_u8x16);
#else   // 32 bit
      uint8x8_t max_val_u8x8 =
          vpmax_u8(vget_low_u8(max_value_u8x16), vget_high_u8(max_value_u8x16));
      max_val_u8x8 = vpmax_u8(max_val_u8x8, max_val_u8x8);
      max_val_u8x8 = vpmax_u8(max_val_u8x8, max_val_u8x8);
      max_val_u8x8 = vpmax_u8(max_val_u8x8, max_val_u8x8);
      max_from_vec = vget_lane_u8(max_val_u8x8, 0);
#endif  // __aarch64__
      if (max_from_vec > max_value) {
        max_value = max_from_vec;
        max_index = i;
      }
    }
  }
  for (int start_idx = max_index; start_idx < max_index + VECTOR_SIZE;
       start_idx++) {
    if (input_data[start_idx] == max_value) {
      max_index = start_idx;
      break;
    }
  }

#endif  // USE_NEON
  // Leftover loop.
  for (; i < size; ++i) {
    const uint8_t curr_value = input_data[i];
    if (curr_value > max_value) {
      max_value = curr_value;
      max_index = i;
    }
  }

  return max_index;
}

// Specializes ArgMinMax function with axis=dims-1.
// In this case, ArgMinMax reduction is applied on contiguous memory.
template <typename T1, typename T2, bool is_arg_max>
inline void ArgMinMaxLastAxis(const RuntimeShape& input_shape,
                              const T1* input_data,
                              const RuntimeShape& output_shape,
                              T2* output_data) {
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 1);
  TFLITE_DCHECK_EQ(input_shape.Dims(0), output_shape.Dims(0));

  int outer_size = input_shape.Dims(0);
  int axis_size = input_shape.Dims(1);
  for (int outer = 0; outer < outer_size; ++outer) {
    if (is_arg_max) {
      output_data[outer] = static_cast<T2>(
          ArgMaxVector<T1>(input_data + outer * axis_size, axis_size));
    } else {
      output_data[outer] = static_cast<T2>(
          ArgMinVector<T1>(input_data + outer * axis_size, axis_size));
    }
  }
}

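// Illustrative usage sketch (not part of the original file): argmax over the
// last axis of a 2x4 float tensor, one result per row.
//
//   const RuntimeShape input_shape({2, 4});
//   const RuntimeShape output_shape({2});
//   const float input[8] = {0.f, 3.f, 1.f, 2.f, 5.f, 4.f, 5.f, 0.f};
//   int32_t output[2];
//   ArgMinMaxLastAxis<float, int32_t, /*is_arg_max=*/true>(
//       input_shape, input, output_shape, output);
//   // output is now {1, 0} (ties resolve to the smaller index).
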
template <typename T1, typename T2, typename T3>
inline void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
                      const T3* input2_data, const RuntimeShape& output_shape,
                      T2* output_data, const bool is_arg_max) {
  ruy::profiler::ScopeLabel label("ArgMinMax");

  TFLITE_DCHECK_GT(input1_shape.DimensionsCount(), 0);
  TFLITE_DCHECK_EQ(input1_shape.DimensionsCount() - 1,
                   output_shape.DimensionsCount());
  int axis = input2_data[0];
  if (axis < 0) {
    axis += input1_shape.DimensionsCount();
  }
  const int axis_size = input1_shape.Dims(axis);

  int outer_size = 1;
  for (int i = 0; i < axis; ++i) {
    TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i));
    outer_size *= input1_shape.Dims(i);
  }

  int inner_size = 1;
  const int dims_count = input1_shape.DimensionsCount();
  for (int i = axis + 1; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i - 1));
    inner_size *= input1_shape.Dims(i);
  }

  // Call the specialized function when axis == dims - 1. So far, only float32,
  // int8 and uint8 are optimized, so reroute to the specialized function only
  // for those input types.
  if (inner_size == 1 &&
      (std::is_same<T1, float>::value || std::is_same<T1, int8_t>::value ||
       std::is_same<T1, uint8_t>::value)) {
    if (is_arg_max) {
      ArgMinMaxLastAxis<T1, T2, /*is_arg_max=*/true>(
          {outer_size, axis_size}, input1_data, {outer_size}, output_data);
    } else {
      ArgMinMaxLastAxis<T1, T2, /*is_arg_max=*/false>(
          {outer_size, axis_size}, input1_data, {outer_size}, output_data);
    }
    return;
  }

  reference_ops::ArgMinMax(input1_shape, input1_data, input2_data, output_shape,
                           output_data, is_arg_max);
}

template <typename T1, typename T2, typename T3>
void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
            const T3* input2_data, const RuntimeShape& output_shape,
            T2* output_data) {
  ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
            /*is_arg_max=*/true);
}

// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
// For backward compatibility, reference_ops has an ArgMax function.
template <typename T1, typename T2, typename T3>
inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
                   const RuntimeShape& input2_shape, const T3* input2_data,
                   const RuntimeShape& output_shape, T2* output_data) {
  // Drop shape of second input: not needed.
  ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
}

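// Illustrative usage sketch (not part of the original file): argmax of a
// 2x3 float tensor along axis 1, with the axis read from the second input.
//
//   const RuntimeShape input_shape({2, 3});
//   const RuntimeShape axis_shape({1});
//   const RuntimeShape output_shape({2});
//   const float input[6] = {0.1f, 0.9f, 0.3f, 0.7f, 0.2f, 0.5f};
//   const int32_t axis[1] = {1};
//   int32_t output[2];
//   ArgMax(input_shape, input, axis_shape, axis, output_shape, output);
//   // output is now {1, 0}.
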
inline void Conv3D(const Conv3DParams& params, const RuntimeShape& input_shape,
                   const float* input_data, const RuntimeShape& filter_shape,
                   const float* filter_data, const RuntimeShape& bias_shape,
                   const float* bias_data, const RuntimeShape& output_shape,
                   float* output_data, const RuntimeShape& im2col_shape,
                   float* im2col_data,
                   const RuntimeShape& transposed_filter_shape,
                   float* transposed_filter_data,
                   CpuBackendContext* cpu_backend_context) {
  const int stride_depth = params.stride_depth;
  const int stride_height = params.stride_height;
  const int stride_width = params.stride_width;
  const int dilation_depth_factor = params.dilation_depth;
  const int dilation_height_factor = params.dilation_height;
  const int dilation_width_factor = params.dilation_width;
  const float output_activation_min = params.float_activation_min;
  const float output_activation_max = params.float_activation_max;
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 5);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 5);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 5);

  ruy::profiler::ScopeLabel label("Conv3D");

  // NB: the float 0.0f value is represented by all zero bytes.
  const uint8 float_zero_byte = 0x00;
  const float* gemm_input_data = nullptr;
  const RuntimeShape* gemm_input_shape = nullptr;
  const int filter_width = filter_shape.Dims(2);
  const int filter_height = filter_shape.Dims(1);
  const int filter_depth = filter_shape.Dims(0);
  const bool need_dilated_im2col = dilation_width_factor != 1 ||
                                   dilation_height_factor != 1 ||
                                   dilation_depth_factor != 1;
  const bool need_im2col = stride_depth != 1 || stride_height != 1 ||
                           stride_width != 1 || filter_depth != 1 ||
                           filter_height != 1 || filter_width != 1;

  if (need_dilated_im2col) {
    DilatedIm2col3D(params, filter_depth, filter_height, filter_width,
                    float_zero_byte, input_shape, input_data, im2col_shape,
                    im2col_data);
    gemm_input_data = im2col_data;
    gemm_input_shape = &im2col_shape;
  } else if (need_im2col) {
    TFLITE_DCHECK(im2col_data);
    Im2col3D(params, filter_depth, filter_height, filter_width, float_zero_byte,
             input_shape, input_data, im2col_shape, im2col_data);
    gemm_input_data = im2col_data;
    gemm_input_shape = &im2col_shape;
  } else {
    TFLITE_DCHECK(!im2col_data);
    gemm_input_data = input_data;
    gemm_input_shape = &input_shape;
  }

  // Transpose the filter tensor.
  TransposeParams transpose_params;
  transpose_params.perm_count = 5;
  transpose_params.perm[0] = 4;
  transpose_params.perm[1] = 0;
  transpose_params.perm[2] = 1;
  transpose_params.perm[3] = 2;
  transpose_params.perm[4] = 3;
  Transpose<float, 5>(transpose_params, filter_shape, filter_data,
                      transposed_filter_shape, transposed_filter_data);

  const int gemm_input_dims = gemm_input_shape->DimensionsCount();
  int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
  int n = output_shape.Dims(4);
  int k = gemm_input_shape->Dims(gemm_input_dims - 1);

  cpu_backend_gemm::MatrixParams<float> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = n;
  lhs_params.cols = k;
  cpu_backend_gemm::MatrixParams<float> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = k;
  rhs_params.cols = m;
  cpu_backend_gemm::MatrixParams<float> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = n;
  dst_params.cols = m;
  cpu_backend_gemm::GemmParams<float, float> gemm_params;
  gemm_params.bias = bias_data;
  gemm_params.clamp_min = output_activation_min;
  gemm_params.clamp_max = output_activation_max;
  cpu_backend_gemm::Gemm(lhs_params, transposed_filter_data, rhs_params,
                         gemm_input_data, dst_params, output_data, gemm_params,
                         cpu_backend_context);
}

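// Explanatory note (not from the original file): the Gemm call above computes
// the whole convolution in one matrix multiply. With the filter transposed so
// that the output-channel dimension comes first, the operand sizes are
//   m = batches * out_depth * out_height * out_width (one im2col row per
//       output position),
//   k = filter_depth * filter_height * filter_width * in_channels (the patch
//       size, i.e. the last im2col dimension), and
//   n = out_channels,
// so the lhs is an n x k transposed filter, the rhs is a k x m matrix of
// patches, and the n x m result (with bias and clamping applied) is written
// directly into output_data.
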
// Returns in 'im_data' (assumed to be zero-initialized) the image patch in
// storage order (planes, height, width, channel), constructed from patches in
// 'col_data', which is required to be in storage order (out_planes * out_height
// * out_width, filter_planes, filter_height, filter_width, in_channel).
//
// This function is copied from tensorflow/core/kernels/conv_grad_ops_3d.cc
// authored by Eugene Zhulenev (ezhulenev).
8196 template <typename T>
Col2im(const T * col_data,const int channel,const int planes,const int height,const int width,const int filter_p,const int filter_h,const int filter_w,const int pad_pt,const int pad_t,const int pad_l,const int pad_pb,const int pad_b,const int pad_r,const int stride_p,const int stride_h,const int stride_w,T * im_data)8197 void Col2im(const T* col_data, const int channel, const int planes,
8198             const int height, const int width, const int filter_p,
8199             const int filter_h, const int filter_w, const int pad_pt,
8200             const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
8201             const int pad_r, const int stride_p, const int stride_h,
8202             const int stride_w, T* im_data) {
8203   const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
8204   const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
8205   const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
8206   int p_pad = -pad_pt;
8207   for (int p = 0; p < planes_col; ++p) {
8208     int h_pad = -pad_t;
8209     for (int h = 0; h < height_col; ++h) {
8210       int w_pad = -pad_l;
8211       for (int w = 0; w < width_col; ++w) {
8212         T* im_patch_data =
8213             im_data +
8214             (p_pad * height * width + h_pad * width + w_pad) * channel;
8215         for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
8216           for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
8217             for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
8218               if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
8219                   iw < width) {
8220                 for (int i = 0; i < channel; ++i) {
8221                   im_patch_data[i] += col_data[i];
8222                 }
8223               }
8224               im_patch_data += channel;
8225               col_data += channel;
8226             }
8227             // Jump over remaining number of channel.
8228             im_patch_data += channel * (width - filter_w);
8229           }
8230           // Jump over remaining number of (channel * width).
8231           im_patch_data += (channel * width) * (height - filter_h);
8232         }
8233         w_pad += stride_w;
8234       }
8235       h_pad += stride_h;
8236     }
8237     p_pad += stride_p;
8238   }
8239 }
8240 
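// Explanatory note (not from the original file): planes_col, height_col and
// width_col above are the number of patch positions along each axis. For
// example, with height = 5, pad_t = pad_b = 0, filter_h = 3 and stride_h = 1,
// height_col = (5 + 0 + 0 - 3) / 1 + 1 = 3, i.e. three overlapping patch rows
// are accumulated back into the image.
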
template <typename T>
void BiasAdd3D(T* im_data, const T* bias_data, const RuntimeShape& input_shape,
               float float_activation_min, float float_activation_max) {
  if (bias_data) {
    const int outer_size = input_shape.Dims(0) * input_shape.Dims(1) *
                           input_shape.Dims(2) * input_shape.Dims(3);
    const int num_channels = input_shape.Dims(4);
    for (int n = 0; n < outer_size; ++n) {
      for (int c = 0; c < num_channels; ++c) {
        im_data[c] = ActivationFunctionWithMinMax(im_data[c] + bias_data[c],
                                                  float_activation_min,
                                                  float_activation_max);
      }
      im_data += num_channels;
    }
  } else {
    const int flat_size = input_shape.FlatSize();
    for (int i = 0; i < flat_size; ++i) {
      im_data[i] = ActivationFunctionWithMinMax(
          im_data[i], float_activation_min, float_activation_max);
    }
  }
}

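// Illustrative usage sketch (not part of the original file): adding a
// per-channel bias in place to a 1x1x1x2x2 tensor, with activation bounds
// wide enough that no clamping occurs.
//
//   const RuntimeShape shape({1, 1, 1, 2, 2});
//   float data[4] = {1.f, 2.f, 3.f, 4.f};
//   const float bias[2] = {10.f, 20.f};
//   BiasAdd3D(data, bias, shape,
//             std::numeric_limits<float>::lowest(),
//             std::numeric_limits<float>::max());
//   // data is now {11.f, 22.f, 13.f, 24.f}.
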
inline void Conv3DTranspose(
    const Conv3DTransposeParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& filter_shape,
    const float* filter_data, const RuntimeShape& bias_shape,
    const float* bias_data, const RuntimeShape& output_shape,
    float* const output_data, const RuntimeShape& col2im_shape,
    float* col2im_data, CpuBackendContext* cpu_backend_context) {
  ruy::profiler::ScopeLabel label("Conv3DTranspose/float");
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 5);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 5);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 5);
  TFLITE_DCHECK(col2im_data);

  const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
  const int input_channel = MatchingDim(input_shape, 4, filter_shape, 4);
  const int output_channel = MatchingDim(output_shape, 4, filter_shape, 3);
  const int input_spatial_size =
      input_shape.Dims(1) * input_shape.Dims(2) * input_shape.Dims(3);
  const int output_spatial_size =
      output_shape.Dims(1) * output_shape.Dims(2) * output_shape.Dims(3);

  const int output_spatial_dim_1 = output_shape.Dims(1);
  const int output_spatial_dim_2 = output_shape.Dims(2);
  const int output_spatial_dim_3 = output_shape.Dims(3);
  const int input_offset = input_spatial_size * input_channel;
  const int output_offset = output_spatial_size * output_channel;

  const int filter_spatial_dim_1 = filter_shape.Dims(0);
  const int filter_spatial_dim_2 = filter_shape.Dims(1);
  const int filter_spatial_dim_3 = filter_shape.Dims(2);

  const int spatial_dim_1_padding_before = params.padding_values.depth;
  const int spatial_dim_1_padding_after =
      params.padding_values.depth + params.padding_values.depth_offset;
  const int spatial_dim_2_padding_before = params.padding_values.height;
  const int spatial_dim_2_padding_after =
      params.padding_values.height + params.padding_values.height_offset;
  const int spatial_dim_3_padding_before = params.padding_values.width;
  const int spatial_dim_3_padding_after =
      params.padding_values.width + params.padding_values.width_offset;
  const int spatial_dim_1_stride = params.stride_depth;
  const int spatial_dim_2_stride = params.stride_height;
  const int spatial_dim_3_stride = params.stride_width;
  const int filter_total_size = filter_spatial_dim_1 * filter_spatial_dim_2 *
                                filter_spatial_dim_3 * output_channel;

  cpu_backend_gemm::MatrixParams<float> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = filter_total_size;
  lhs_params.cols = input_channel;
  float* output_data_p = output_data;
  std::fill_n(output_data, output_offset * batch_size, 0.0f);
  for (int i = 0; i < batch_size; ++i) {
    cpu_backend_gemm::MatrixParams<float> rhs_params;
    rhs_params.order = cpu_backend_gemm::Order::kColMajor;
    rhs_params.rows = input_channel;
    rhs_params.cols = input_spatial_size;
    cpu_backend_gemm::MatrixParams<float> dst_params;
    dst_params.order = cpu_backend_gemm::Order::kColMajor;
    dst_params.rows = filter_total_size;
    dst_params.cols = input_spatial_size;
    cpu_backend_gemm::GemmParams<float, float> gemm_params;
    cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params,
                           input_data + input_offset * i, dst_params,
                           col2im_data, gemm_params, cpu_backend_context);

    Col2im(col2im_data, output_channel, output_spatial_dim_1,
           output_spatial_dim_2, output_spatial_dim_3, filter_spatial_dim_1,
           filter_spatial_dim_2, filter_spatial_dim_3,
           spatial_dim_1_padding_before, spatial_dim_2_padding_before,
           spatial_dim_3_padding_before, spatial_dim_1_padding_after,
           spatial_dim_2_padding_after, spatial_dim_3_padding_after,
           spatial_dim_1_stride, spatial_dim_2_stride, spatial_dim_3_stride,
           output_data_p);
    output_data_p += output_offset;
  }
  output_data_p = output_data;
  BiasAdd3D(output_data_p, bias_data, output_shape, params.float_activation_min,
            params.float_activation_max);
}

}  // namespace optimized_ops
}  // namespace tflite

#if defined OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
#undef OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
#pragma GCC diagnostic pop
#endif

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_