1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
17 
18 #include <assert.h>
19 #include <stdint.h>
20 #include <sys/types.h>
21 #include <algorithm>
22 #include <cmath>
23 #include <cstdint>
24 #include <limits>
25 #include <memory>
26 #include <tuple>
27 #include <type_traits>
28 
29 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
30 #include <Accelerate/Accelerate.h>
31 #endif
32 
33 #include "Eigen/Core"
34 #include "fixedpoint/fixedpoint.h"
35 #include "public/gemmlowp.h"
36 #include "tensorflow/lite/kernels/internal/common.h"
37 #include "tensorflow/lite/kernels/internal/quantization_util.h"
38 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
39 #include "tensorflow/lite/kernels/internal/round.h"
40 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
41 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
42 #include "tensorflow/lite/kernels/internal/types.h"
43 #include "unsupported/Eigen/CXX11/Tensor"
44 
45 namespace tflite {
46 namespace optimized_ops {
47 
48 // Unoptimized reference ops:
49 using reference_ops::ArgMax;
50 using reference_ops::ArgMinMax;
51 using reference_ops::Broadcast4DSlowGreater;
52 using reference_ops::Broadcast4DSlowGreaterEqual;
53 using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
54 using reference_ops::Broadcast4DSlowGreaterWithScaling;
55 using reference_ops::Broadcast4DSlowLess;
56 using reference_ops::Broadcast4DSlowLessEqual;
57 using reference_ops::Broadcast4DSlowLessEqualWithScaling;
58 using reference_ops::Broadcast4DSlowLessWithScaling;
59 using reference_ops::BroadcastAdd4DSlow;
60 using reference_ops::BroadcastMul4DSlow;
61 using reference_ops::BroadcastSub4DSlow;
62 using reference_ops::Concatenation;
63 using reference_ops::ConcatenationWithScaling;
64 using reference_ops::DepthConcatenation;
65 using reference_ops::Dequantize;
66 using reference_ops::Div;
67 using reference_ops::Elu;
68 using reference_ops::FakeQuant;
69 using reference_ops::Fill;
70 using reference_ops::Gather;
71 using reference_ops::Greater;
72 using reference_ops::GreaterEqual;
73 using reference_ops::GreaterEqualWithScaling;
74 using reference_ops::GreaterWithScaling;
75 using reference_ops::LeakyRelu;
76 using reference_ops::Less;
77 using reference_ops::LessEqual;
78 using reference_ops::LessEqualWithScaling;
79 using reference_ops::LessWithScaling;
80 using reference_ops::Mean;
81 using reference_ops::ProcessBroadcastShapes;
82 using reference_ops::RankOneSelect;
83 using reference_ops::Relu1;
84 using reference_ops::Relu6;
85 using reference_ops::ReluX;
86 using reference_ops::Select;
87 using reference_ops::SpaceToBatchND;
88 using reference_ops::Split;
89 using reference_ops::StridedSlice;
90 using reference_ops::Sub16;
91 using reference_ops::Transpose;
92 
93 // TODO(b/80247582) Remove this constant.
94 // This will be phased out as the shifts are revised with more thought. Use of a
95 // constant enables us to track progress on this work.
96 //
97 // Used to convert from old-style shifts (right) to new-style (left).
98 static constexpr int kReverseShift = -1;
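// For example (illustrative): an op whose legacy parameters specify a right
// shift of 3 uses kReverseShift * 3 == -3 in the new signed-left-shift
// convention.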
99 
100 // Make a local VectorMap typedef that lets us map a float array
101 // as an Eigen vector expression. The std::conditional here is used to
102 // construct the Eigen type suitable for the constness of the
103 // data. Indeed, for const data, we need to produce
104 //    Eigen::Map<const Eigen::Matrix<float, ...>>
105 // and not the more straightforward
106 //    Eigen::Map<Eigen::Matrix<const float, ...>>
107 template <typename Scalar>
108 using VectorMap = typename std::conditional<
109     std::is_const<Scalar>::value,
110     Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
111                                    Eigen::Dynamic, 1>>,
112     Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, 1>>>::type;
113 
114 template <typename Scalar>
115 VectorMap<Scalar> MapAsVector(Scalar* data, const RuntimeShape& shape) {
116   const int size = shape.FlatSize();
117   return VectorMap<Scalar>(data, size, 1);
118 }
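// Illustrative use (a minimal sketch; the buffer and shape are placeholders):
//   float data[6] = {1.f, -2.f, 3.f, -4.f, 5.f, -6.f};
//   auto vec = MapAsVector(data, RuntimeShape({2, 3}));  // length-6 vector
//   vec = vec.cwiseMax(0.f);  // e.g. an in-place ReLU over the whole buffer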
119 
120 // Make a local MatrixMap typedef that lets us map a float array
121 // as an Eigen matrix expression. The same explanation as for VectorMap
122 // above also applies here.
123 template <typename Scalar>
124 using MatrixMap = typename std::conditional<
125     std::is_const<Scalar>::value,
126     Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
127                                    Eigen::Dynamic, Eigen::Dynamic>>,
128     Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
129 
130 template <typename Scalar>
131 MatrixMap<Scalar> MapAsMatrixWithLastDimAsRows(Scalar* data,
132                                                const RuntimeShape& shape) {
133   const int dims_count = shape.DimensionsCount();
134   const int rows = shape.Dims(dims_count - 1);
135   const int cols = FlatSizeSkipDim(shape, dims_count - 1);
136   return MatrixMap<Scalar>(data, rows, cols);
137 }
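// For example (illustrative), MapAsMatrixWithLastDimAsRows views a {2, 3, 4}
// shape as a matrix with 4 rows (the last dimension) and 2 * 3 = 6 columns.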
138 
139 template <typename Scalar>
140 MatrixMap<Scalar> MapAsMatrixWithFirstDimAsCols(Scalar* data,
141                                                 const RuntimeShape& shape) {
142   const int cols = shape.Dims(0);
143   const int rows = FlatSizeSkipDim(shape, 0);
144   return MatrixMap<Scalar>(data, rows, cols);
145 }
146 
147 template <typename Scalar>
148 using ArrayMap = typename std::conditional<
149     std::is_const<Scalar>::value,
150     Eigen::Map<const Eigen::Array<typename std::remove_const<Scalar>::type,
151                                   Eigen::Dynamic, Eigen::Dynamic>>,
152     Eigen::Map<Eigen::Array<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
153 
154 template <typename Scalar>
155 ArrayMap<Scalar> MapAsArrayWithLastDimAsRows(Scalar* data,
156                                              const RuntimeShape& shape) {
157   const int dims_count = shape.DimensionsCount();
158   const int rows = shape.Dims(dims_count - 1);
159   const int cols = FlatSizeSkipDim(shape, dims_count - 1);
160   return ArrayMap<Scalar>(data, rows, cols);
161 }
162 
163 // Copied from tensorflow/core/framework/tensor_types.h
164 template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
165 struct TTypes {
166   // Rank-1 tensor (vector) of scalar type T.
167   typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
168                            Eigen::Aligned>
169       Flat;
170   typedef Eigen::TensorMap<
171       Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>>
172       UnalignedConstMatrix;
173 };
174 
175 // TODO(b/62193649): this function is only needed as long
176 // as we have the --variable_batch hack.
177 template <typename Scalar>
178 MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
179                                                    const RuntimeShape& shape,
180                                                    int rows) {
181   const int flatsize = shape.FlatSize();
182   TFLITE_DCHECK_EQ(flatsize % rows, 0);
183   const int cols = flatsize / rows;
184   return MatrixMap<Scalar>(data, rows, cols);
185 }
186 
187 inline void AddBiasAndEvalActivationFunction(float output_activation_min,
188                                              float output_activation_max,
189                                              const RuntimeShape& bias_shape,
190                                              const float* bias_data,
191                                              const RuntimeShape& array_shape,
192                                              float* array_data) {
193 #ifdef USE_NEON
194   gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
195   const int bias_size = bias_shape.FlatSize();
196   const int array_size = array_shape.FlatSize();
197   TFLITE_DCHECK_EQ((array_size % bias_size), 0);
198   float* array_ptr = array_data;
199   float* array_end_ptr = array_ptr + array_size;
200   const auto activation_min = vdupq_n_f32(output_activation_min);
201   const auto activation_max = vdupq_n_f32(output_activation_max);
202   for (; array_ptr != array_end_ptr; array_ptr += bias_size) {
203     int i = 0;
204     for (; i <= bias_size - 16; i += 16) {
205       auto b0 = vld1q_f32(bias_data + i);
206       auto b1 = vld1q_f32(bias_data + i + 4);
207       auto b2 = vld1q_f32(bias_data + i + 8);
208       auto b3 = vld1q_f32(bias_data + i + 12);
209       auto a0 = vld1q_f32(array_ptr + i);
210       auto a1 = vld1q_f32(array_ptr + i + 4);
211       auto a2 = vld1q_f32(array_ptr + i + 8);
212       auto a3 = vld1q_f32(array_ptr + i + 12);
213       auto x0 = vaddq_f32(a0, b0);
214       auto x1 = vaddq_f32(a1, b1);
215       auto x2 = vaddq_f32(a2, b2);
216       auto x3 = vaddq_f32(a3, b3);
217       x0 = vmaxq_f32(activation_min, x0);
218       x1 = vmaxq_f32(activation_min, x1);
219       x2 = vmaxq_f32(activation_min, x2);
220       x3 = vmaxq_f32(activation_min, x3);
221       x0 = vminq_f32(activation_max, x0);
222       x1 = vminq_f32(activation_max, x1);
223       x2 = vminq_f32(activation_max, x2);
224       x3 = vminq_f32(activation_max, x3);
225       vst1q_f32(array_ptr + i, x0);
226       vst1q_f32(array_ptr + i + 4, x1);
227       vst1q_f32(array_ptr + i + 8, x2);
228       vst1q_f32(array_ptr + i + 12, x3);
229     }
230     for (; i <= bias_size - 4; i += 4) {
231       auto b = vld1q_f32(bias_data + i);
232       auto a = vld1q_f32(array_ptr + i);
233       auto x = vaddq_f32(a, b);
234       x = vmaxq_f32(activation_min, x);
235       x = vminq_f32(activation_max, x);
236       vst1q_f32(array_ptr + i, x);
237     }
238     for (; i < bias_size; i++) {
239       array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i],
240                                                   output_activation_min,
241                                                   output_activation_max);
242     }
243   }
244 #else  // not NEON
245   gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
246   const int bias_size = bias_shape.FlatSize();
247   const int array_size = array_shape.FlatSize();
248   TFLITE_DCHECK_EQ((array_size % bias_size), 0);
249   for (int array_offset = 0; array_offset < array_size;
250        array_offset += bias_size) {
251     for (int i = 0; i < bias_size; i++) {
252       array_data[array_offset + i] = ActivationFunctionWithMinMax(
253           array_data[array_offset + i] + bias_data[i], output_activation_min,
254           output_activation_max);
255     }
256   }
257 #endif
258 }
259 
260 template <typename Lhs, typename Rhs, typename Result>
261 void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
262           Eigen::MatrixBase<Result>* result) {
263   if (rhs.cols() == 1) {
264     gemmlowp::ScopedProfilingLabel label("GEMV");
265     result->col(0).noalias() = lhs * rhs.col(0);
266   } else {
267     gemmlowp::ScopedProfilingLabel label("GEMM");
268     result->noalias() = lhs * rhs;
269   }
270 }
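// Illustrative use (placeholder names), mirroring the float FullyConnected
// path further below:
//   auto filter = MapAsMatrixWithLastDimAsRows(weights_data, weights_shape);
//   auto input = MapAsMatrixWithGivenNumberOfRows(input_data, input_shape,
//                                                 input_rows);
//   auto output = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
//   Gemm(filter.transpose(), input, &output);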
271 
272 inline void optimized_ops_preload_l1_stream(const uint8* ptr) {
273 #ifdef GEMMLOWP_ARM_64
274   asm volatile("prfm pldl1strm, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
275 #else
276   gemmlowp::Prefetch(ptr);
277 #endif
278 }
279 
280 inline void optimized_ops_preload_l1_keep(const uint8* ptr) {
281 #ifdef GEMMLOWP_ARM_64
282   asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
283 #else
284   gemmlowp::Prefetch(ptr);
285 #endif
286 }
287 
288 #ifdef GEMMLOWP_NEON
289 // In the common case of batch size 1, a fully-connected node degenerates
290 // to a matrix*vector product. LSTM cells contain a fully-connected node;
291 // when quantized, this becomes a special type of GEMV operation where
292 // the output is 16-bit-quantized, and thus needs its own special path.
293 inline void GEMVForLstmCell(const RuntimeShape& input_shape,
294                             const uint8* input_data,
295                             const RuntimeShape& weights_shape,
296                             const uint8* weights_data, uint8 weights_zero_point,
297                             const RuntimeShape& bias_shape,
298                             const int32* bias_data, int32 accum_multiplier,
299                             int accum_shift, const RuntimeShape& output_shape,
300                             int16* output_data) {
301   gemmlowp::ScopedProfilingLabel label("GEMVForLstmCell");
302   TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
303   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
304   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
305   const int output_dim_count = output_shape.DimensionsCount();
306   const int weights_dim_count = weights_shape.DimensionsCount();
307   TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
308   const int input_size = FlatSizeSkipDim(input_shape, 0);
309   const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
310                                       output_shape, output_dim_count - 1);
311   // This special fast path for quantized LSTM cells does not try to support
312   // odd sizes that we haven't encountered in any LSTM cell, as that would
313   // require special code (which would go untested until some LSTM cell
314   // exercises it). We just guard our assumptions about size evenness with
315   // the following assertions.
316   TFLITE_DCHECK(!(output_size % 4));
317   TFLITE_DCHECK(!(input_size % 8));
318   const int32* bias_ptr = bias_data;
319   int16* output_ptr = output_data;
320   for (int out = 0; out < output_size; out += 4) {
321     int32x4_t acc_0 = vdupq_n_s32(0);
322     int32x4_t acc_1 = vdupq_n_s32(0);
323     int32x4_t acc_2 = vdupq_n_s32(0);
324     int32x4_t acc_3 = vdupq_n_s32(0);
325     const int16x8_t input_offset_vec = vdupq_n_s16(-128);
326     const int16x8_t weights_offset_vec = vdupq_n_s16(-weights_zero_point);
327     int in = 0;
328     // Handle 16 levels of depth at a time.
329     for (; in <= input_size - 16; in += 16) {
330       const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
331       const uint8* weights_ptr = weights_data + in + out * input_size;
332       uint8x16_t weights_val_u8_0 = vld1q_u8(weights_ptr + 0 * input_size);
333       uint8x16_t weights_val_u8_1 = vld1q_u8(weights_ptr + 1 * input_size);
334       uint8x16_t weights_val_u8_2 = vld1q_u8(weights_ptr + 2 * input_size);
335       uint8x16_t weights_val_u8_3 = vld1q_u8(weights_ptr + 3 * input_size);
336       int16x8_t input_val_0, input_val_1;
337       const uint8x8_t low = vget_low_u8(input_val_u8);
338       const uint8x8_t high = vget_high_u8(input_val_u8);
339       input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low));
340       input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high));
341       input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
342       input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
343       int16x8_t weights_val_0_0, weights_val_1_0, weights_val_2_0,
344           weights_val_3_0;
345       int16x8_t weights_val_0_1, weights_val_1_1, weights_val_2_1,
346           weights_val_3_1;
347       weights_val_0_0 = vaddq_s16(
348           vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_0))),
349           weights_offset_vec);
350       weights_val_0_1 = vaddq_s16(
351           vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_0))),
352           weights_offset_vec);
353       weights_val_1_0 = vaddq_s16(
354           vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_1))),
355           weights_offset_vec);
356       weights_val_1_1 = vaddq_s16(
357           vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_1))),
358           weights_offset_vec);
359       weights_val_2_0 = vaddq_s16(
360           vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_2))),
361           weights_offset_vec);
362       weights_val_2_1 = vaddq_s16(
363           vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_2))),
364           weights_offset_vec);
365       weights_val_3_0 = vaddq_s16(
366           vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_3))),
367           weights_offset_vec);
368       weights_val_3_1 = vaddq_s16(
369           vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_3))),
370           weights_offset_vec);
371       acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_0),
372                         vget_low_s16(input_val_0));
373       acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_0),
374                         vget_low_s16(input_val_0));
375       acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_0),
376                         vget_low_s16(input_val_0));
377       acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_0),
378                         vget_low_s16(input_val_0));
379       acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_0),
380                         vget_high_s16(input_val_0));
381       acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_0),
382                         vget_high_s16(input_val_0));
383       acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_0),
384                         vget_high_s16(input_val_0));
385       acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_0),
386                         vget_high_s16(input_val_0));
387       acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_1),
388                         vget_low_s16(input_val_1));
389       acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_1),
390                         vget_low_s16(input_val_1));
391       acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_1),
392                         vget_low_s16(input_val_1));
393       acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_1),
394                         vget_low_s16(input_val_1));
395       acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_1),
396                         vget_high_s16(input_val_1));
397       acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_1),
398                         vget_high_s16(input_val_1));
399       acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_1),
400                         vget_high_s16(input_val_1));
401       acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_1),
402                         vget_high_s16(input_val_1));
403     }
404     // Handle 8 levels of depth at a time.
405     for (; in < input_size; in += 8) {
406       const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
407       const uint8* weights_ptr = weights_data + in + out * input_size;
408       uint8x8_t weights_val_u8_0 = vld1_u8(weights_ptr + 0 * input_size);
409       uint8x8_t weights_val_u8_1 = vld1_u8(weights_ptr + 1 * input_size);
410       uint8x8_t weights_val_u8_2 = vld1_u8(weights_ptr + 2 * input_size);
411       uint8x8_t weights_val_u8_3 = vld1_u8(weights_ptr + 3 * input_size);
412       int16x8_t input_val;
413       input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
414       input_val = vaddq_s16(input_val, input_offset_vec);
415       int16x8_t weights_val_0, weights_val_1, weights_val_2, weights_val_3;
416       weights_val_0 =
417           vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_0)),
418                     weights_offset_vec);
419       weights_val_1 =
420           vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_1)),
421                     weights_offset_vec);
422       weights_val_2 =
423           vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_2)),
424                     weights_offset_vec);
425       weights_val_3 =
426           vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_3)),
427                     weights_offset_vec);
428       acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0),
429                         vget_low_s16(input_val));
430       acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1),
431                         vget_low_s16(input_val));
432       acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2),
433                         vget_low_s16(input_val));
434       acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3),
435                         vget_low_s16(input_val));
436       acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0),
437                         vget_high_s16(input_val));
438       acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1),
439                         vget_high_s16(input_val));
440       acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2),
441                         vget_high_s16(input_val));
442       acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3),
443                         vget_high_s16(input_val));
444     }
445     // Horizontally reduce accumulators
446     int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
447         pairwise_reduced_acc_2, pairwise_reduced_acc_3;
448     pairwise_reduced_acc_0 =
449         vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0));
450     pairwise_reduced_acc_1 =
451         vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1));
452     pairwise_reduced_acc_2 =
453         vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2));
454     pairwise_reduced_acc_3 =
455         vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3));
456     const int32x2_t reduced_lo =
457         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
458     const int32x2_t reduced_hi =
459         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
460     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
461     // Add bias values.
462     int32x4_t bias_vec = vld1q_s32(bias_ptr);
463     bias_ptr += 4;
464     reduced = vaddq_s32(reduced, bias_vec);
465     int left_shift = accum_shift > 0 ? accum_shift : 0;
466     int right_shift = accum_shift > 0 ? 0 : -accum_shift;
467     reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
468     // Multiply by the fixed-point multiplier.
469     reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
470     // Rounding-shift-right.
471     using gemmlowp::RoundingDivideByPOT;
472     reduced = RoundingDivideByPOT(reduced, right_shift);
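    // Net effect of the three steps above (descriptive note): reduced is
    // rescaled by roughly accum_multiplier * 2^accum_shift / 2^31, with
    // rounding, i.e. the usual TFLite fixed-point requantization.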
473     // Narrow values down to 16 bit signed.
474     const int16x4_t res16 = vqmovn_s32(reduced);
475     vst1_s16(output_ptr, res16);
476     output_ptr += 4;
477   }
478 }
479 #endif
480 
481 #ifdef GEMMLOWP_NEON
482 inline void GEMVForLstmCellWithSymmetricRange(
483     const RuntimeShape& input_shape, const uint8* input_data,
484     const RuntimeShape& weights_shape, const uint8* weights_data,
485     const RuntimeShape& bias_shape, const int32* bias_data,
486     int32 accum_multiplier, int accum_shift, const RuntimeShape& output_shape,
487     int16* output_data) {
488   gemmlowp::ScopedProfilingLabel label("GEMVForLstmCellWithSymmetricRange");
489   TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
490   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
491   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
492   const int output_dim_count = output_shape.DimensionsCount();
493   const int weights_dim_count = weights_shape.DimensionsCount();
494   TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
495   const int input_size = FlatSizeSkipDim(input_shape, 0);
496   const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
497                                       output_shape, output_dim_count - 1);
498   // This special fast path for quantized LSTM cells does not try to support
499   // odd sizes that we haven't encountered in any LSTM cell, as that would
500   // require special code (which would go untested until some LSTM cell
501   // exercises it). We just guard our assumptions about size evenness with
502   // the following assertions.
503   TFLITE_DCHECK(!(output_size % 4));
504   TFLITE_DCHECK(!(input_size % 64));
505   const int32* bias_ptr = bias_data;
506   int16* output_ptr = output_data;
507   const uint8x16_t signbit = vdupq_n_u8(0x80);
508   for (int in = 0; in < input_size; in += 32) {
509     optimized_ops_preload_l1_keep(input_data + in);
510   }
511   const int left_shift = accum_shift > 0 ? accum_shift : 0;
512   const int right_shift = accum_shift > 0 ? 0 : -accum_shift;
513   for (int out = 0; out < output_size; out += 4) {
514     // Load the bias values
515     int32x4_t bias_vec = vld1q_s32(bias_ptr);
516     bias_ptr += 4;
517 
518     // Clear accumulators. We use 2 accumulator registers per row,
519     // for 4 rows. row_accumRN is the N-th accumulator for row R.
520     int32x4_t row_accum00 = vdupq_n_s32(0);
521     int32x4_t row_accum01 = vdupq_n_s32(0);
522     int32x4_t row_accum10 = vdupq_n_s32(0);
523     int32x4_t row_accum11 = vdupq_n_s32(0);
524     int32x4_t row_accum20 = vdupq_n_s32(0);
525     int32x4_t row_accum21 = vdupq_n_s32(0);
526     int32x4_t row_accum30 = vdupq_n_s32(0);
527     int32x4_t row_accum31 = vdupq_n_s32(0);
528 
529     // kReadAhead parametrizes how far ahead we prefetch weights into L1 cache.
530     const int kReadAhead = 512;
531     // Prefetch the first weights values.
532     for (int k = 0; k < kReadAhead; k += 64) {
533       optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
534                                       k);
535       optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
536                                       k);
537       optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
538                                       k);
539       optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
540                                       k);
541     }
542     // Loop along the rows, handling 64 bytes per iteration because that's
543     // the cache line size on most current ARM-architecture CPUs.
544     for (int in = 0; in < input_size; in += 64) {
545       // Prefetch some future weights values.
546       optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
547                                       in + kReadAhead);
548       optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
549                                       in + kReadAhead);
550       optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
551                                       in + kReadAhead);
552       optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
553                                       in + kReadAhead);
554 
555       // We will use 2 local 16-bit accumulators per row, for 2 rows.
556       // See below (*) for the rationale of processing only 2 rows at a time.
557       // local_accumRN is the N-th local accumulator for row R.
558       int16x8_t local_accum00;
559       int16x8_t local_accum01;
560       int16x8_t local_accum10;
561       int16x8_t local_accum11;
562 
563       // Load 64 bytes of input activation values. Convert to signed int8
564       // by flipping the sign bit (i.e. subtracting 128, the required
565       // zero_point value).
566       int8x16_t input0 = vreinterpretq_s8_u8(
567           veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 0)));
568       int8x16_t input1 = vreinterpretq_s8_u8(
569           veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 1)));
570       int8x16_t input2 = vreinterpretq_s8_u8(
571           veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 2)));
572       int8x16_t input3 = vreinterpretq_s8_u8(
573           veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 3)));
574 
575       // Beginning of the core accumulation. Notice how while we have 4
576       // rows to process, this code is taking care of only 2 rows at a time,
577       // thus being divided into two parts looking similar ("Rows 0 and 1" and
578       // "Rows 2 and 3").
579       //
580       // (*) The rationale for handling only 2 rows at a time is to avoid
581       // cache aliasing issues on 4-way set-associative L1-cache CPUs, such
582       // as Cortex-A53. With sufficiently large, power-of-two matrix dimensions,
583       // we may find ourselves in a situation where rows alias each other in
584       // the L1 cache, and moreover may also mutually alias with the input
585       // activations. If we try to load 4 rows at a time, together with the
586       // input activations, that may be 5 mutually-aliasing vectors, resulting
587       // in constant mutual eviction from L1 cache. Handling 2 rows at a time
588       // here largely mitigates these issues, and seems at least to be very
589       // effective on Cortex-A53:
590       //                          Before       After
591       // big (Cortex-A73)         2.85 ms      2.85 ms
592       // little (Cortex-A53)      11.0 ms      5.16 ms
593 
594       // Rows 0 and 1:
595       // Load 64 bytes of weights values from each row. Convert to signed int8
596       // by flipping the sign bit (i.e. subtracting 128, the required
597       // zero_point value).
598       int8x16_t weights00 = vreinterpretq_s8_u8(veorq_u8(
599           signbit,
600           vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 0)));
601       int8x16_t weights01 = vreinterpretq_s8_u8(veorq_u8(
602           signbit,
603           vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 1)));
604       int8x16_t weights02 = vreinterpretq_s8_u8(veorq_u8(
605           signbit,
606           vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 2)));
607       int8x16_t weights03 = vreinterpretq_s8_u8(veorq_u8(
608           signbit,
609           vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 3)));
610       int8x16_t weights10 = vreinterpretq_s8_u8(veorq_u8(
611           signbit,
612           vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 0)));
613       int8x16_t weights11 = vreinterpretq_s8_u8(veorq_u8(
614           signbit,
615           vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 1)));
616       int8x16_t weights12 = vreinterpretq_s8_u8(veorq_u8(
617           signbit,
618           vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 2)));
619       int8x16_t weights13 = vreinterpretq_s8_u8(veorq_u8(
620           signbit,
621           vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 3)));
622       // Multiply-accumulate into local 16-bit accumulators.
623       // We can accumulate two products without overflow because weights are
624       // required to never be -128, so each product is at most 127^2 in absolute
625       // value.
626       local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
627       local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
628       local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
629       local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
630       local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
631                                vget_high_s8(input0));
632       local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
633                                vget_high_s8(input1));
634       local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
635                                vget_high_s8(input0));
636       local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
637                                vget_high_s8(input1));
638       // Pairwise add and accumulate into 32-bit accumulators
639       row_accum00 = vpadalq_s16(row_accum00, local_accum00);
640       row_accum01 = vpadalq_s16(row_accum01, local_accum01);
641       row_accum10 = vpadalq_s16(row_accum10, local_accum10);
642       row_accum11 = vpadalq_s16(row_accum11, local_accum11);
643       // Multiply-accumulate into local 16-bit accumulators.
644       // We can accumulate two products without overflow because weights are
645       // required to never be -128, so each product is at most 127^2 in absolute
646       // value.
647       local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
648       local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
649       local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
650       local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
651       local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
652                                vget_high_s8(input2));
653       local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
654                                vget_high_s8(input3));
655       local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
656                                vget_high_s8(input2));
657       local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
658                                vget_high_s8(input3));
659       // Pairwise add and accumulate into 32-bit accumulators
660       row_accum00 = vpadalq_s16(row_accum00, local_accum00);
661       row_accum01 = vpadalq_s16(row_accum01, local_accum01);
662       row_accum10 = vpadalq_s16(row_accum10, local_accum10);
663       row_accum11 = vpadalq_s16(row_accum11, local_accum11);
664 
665       // Rows 2 and 3:
666       // Load 64 bytes of weights values from each row. Convert to signed int8
667       // by flipping the sign bit (i.e. subtracting 128, the required
668       // zero_point value).
669       weights00 = vreinterpretq_s8_u8(veorq_u8(
670           signbit,
671           vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 0)));
672       weights01 = vreinterpretq_s8_u8(veorq_u8(
673           signbit,
674           vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 1)));
675       weights02 = vreinterpretq_s8_u8(veorq_u8(
676           signbit,
677           vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 2)));
678       weights03 = vreinterpretq_s8_u8(veorq_u8(
679           signbit,
680           vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 3)));
681       weights10 = vreinterpretq_s8_u8(veorq_u8(
682           signbit,
683           vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 0)));
684       weights11 = vreinterpretq_s8_u8(veorq_u8(
685           signbit,
686           vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 1)));
687       weights12 = vreinterpretq_s8_u8(veorq_u8(
688           signbit,
689           vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 2)));
690       weights13 = vreinterpretq_s8_u8(veorq_u8(
691           signbit,
692           vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 3)));
693       // Multiply-accumulate into local 16-bit accumulators.
694       // We can accumulate two products without overflow because weights are
695       // required to never be -128, so each product is at most 127^2 in absolute
696       // value.
697       local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
698       local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
699       local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
700       local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
701       local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
702                                vget_high_s8(input0));
703       local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
704                                vget_high_s8(input1));
705       local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
706                                vget_high_s8(input0));
707       local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
708                                vget_high_s8(input1));
709       // Pairwise add and accumulate into 32-bit accumulators
710       row_accum20 = vpadalq_s16(row_accum20, local_accum00);
711       row_accum21 = vpadalq_s16(row_accum21, local_accum01);
712       row_accum30 = vpadalq_s16(row_accum30, local_accum10);
713       row_accum31 = vpadalq_s16(row_accum31, local_accum11);
714       // Multiply-accumulate into local 16-bit accumulators.
715       // We can accumulate two products without overflow because weights are
716       // required to never be -128, so each product is at most 127^2 in absolute
717       // value.
718       local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
719       local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
720       local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
721       local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
722       local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
723                                vget_high_s8(input2));
724       local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
725                                vget_high_s8(input3));
726       local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
727                                vget_high_s8(input2));
728       local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
729                                vget_high_s8(input3));
730       // Pairwise add and accumulate into 32-bit accumulators
731       row_accum20 = vpadalq_s16(row_accum20, local_accum00);
732       row_accum21 = vpadalq_s16(row_accum21, local_accum01);
733       row_accum30 = vpadalq_s16(row_accum30, local_accum10);
734       row_accum31 = vpadalq_s16(row_accum31, local_accum11);
735     }
736 
737     row_accum00 = vaddq_s32(row_accum00, row_accum01);
738     row_accum10 = vaddq_s32(row_accum10, row_accum11);
739     row_accum20 = vaddq_s32(row_accum20, row_accum21);
740     row_accum30 = vaddq_s32(row_accum30, row_accum31);
741     // Horizontally reduce accumulators
742     int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
743         pairwise_reduced_acc_2, pairwise_reduced_acc_3;
744     pairwise_reduced_acc_0 =
745         vpadd_s32(vget_low_s32(row_accum00), vget_high_s32(row_accum00));
746     pairwise_reduced_acc_1 =
747         vpadd_s32(vget_low_s32(row_accum10), vget_high_s32(row_accum10));
748     pairwise_reduced_acc_2 =
749         vpadd_s32(vget_low_s32(row_accum20), vget_high_s32(row_accum20));
750     pairwise_reduced_acc_3 =
751         vpadd_s32(vget_low_s32(row_accum30), vget_high_s32(row_accum30));
752     const int32x2_t reduced_lo =
753         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
754     const int32x2_t reduced_hi =
755         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
756     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
757     // Add bias values.
758     reduced = vaddq_s32(reduced, bias_vec);
759     reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
760     // Multiply by the fixed-point multiplier.
761     reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
762     // Rounding-shift-right.
763     using gemmlowp::RoundingDivideByPOT;
764     reduced = RoundingDivideByPOT(reduced, right_shift);
765     // Narrow values down to 16 bit signed.
766     const int16x4_t res16 = vqmovn_s32(reduced);
767     vst1_s16(output_ptr, res16);
768     output_ptr += 4;
769   }
770 }
771 #endif
772 
773 inline void FullyConnected(
774     const FullyConnectedParams& params, const RuntimeShape& input_shape,
775     const float* input_data, const RuntimeShape& weights_shape,
776     const float* weights_data, const RuntimeShape& bias_shape,
777     const float* optional_bias_data, const RuntimeShape& output_shape,
778     float* output_data) {
779   gemmlowp::ScopedProfilingLabel label("FullyConnected");
780   const float output_activation_min = params.float_activation_min;
781   const float output_activation_max = params.float_activation_max;
782 
783   // TODO(b/62193649): this convoluted shape computation (determining
784   // input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows)
785   // is because the current --variable_batch hack consists in overwriting the
786   // 3rd dimension with the runtime batch size, as we don't keep track, for
787   // each array, of which of its dimensions is the batch dimension.
788   // When that is fixed, this should become:
789   // const auto input_matrix_map =
790   //     MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
791   const int dims_count = weights_shape.DimensionsCount();
792   const int input_rows = weights_shape.Dims(dims_count - 1);
793   const auto input_matrix_map =
794       MapAsMatrixWithGivenNumberOfRows(input_data, input_shape, input_rows);
795   const auto filter_matrix_map =
796       MapAsMatrixWithLastDimAsRows(weights_data, weights_shape);
797   auto output_matrix_map =
798       MapAsMatrixWithLastDimAsRows(output_data, output_shape);
799 
800   Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
801 
802   if (optional_bias_data != nullptr) {
803     AddBiasAndEvalActivationFunction(
804         output_activation_min, output_activation_max, bias_shape,
805         optional_bias_data, output_shape, output_data);
806   } else {
807     const int flat_size = output_shape.FlatSize();
808     for (int i = 0; i < flat_size; ++i) {
809       output_data[i] = ActivationFunctionWithMinMax(
810           output_data[i], output_activation_min, output_activation_max);
811     }
812   }
813 }
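// Illustrative call (a minimal sketch; all shapes and buffers below are
// hypothetical): a batch of 1 mapped through a 20->10 layer with ReLU.
//   FullyConnectedParams params;
//   params.float_activation_min = 0.0f;
//   params.float_activation_max = std::numeric_limits<float>::max();
//   FullyConnected(params, RuntimeShape({1, 20}), input_data,
//                  RuntimeShape({10, 20}), weights_data, RuntimeShape({10}),
//                  bias_data, RuntimeShape({1, 10}), output_data);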
814 
815 #ifdef USE_NEON
816 inline void FullyConnectedAsGEMVWorkerImpl(
817     const RuntimeShape& input_shape, const uint8* input_data,
818     int32 input_offset, const RuntimeShape& filter_shape,
819     const uint8* filter_data, int32 filter_offset,
820     const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
821     int32 output_multiplier, int output_shift, int32 output_activation_min,
822     int32 output_activation_max, const RuntimeShape& output_shape,
823     uint8* output_data, int row_start, int row_end) {
824   gemmlowp::ScopedProfilingLabel label("FullyConnectedAsGEMV/8bit");
825   TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
826   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
827   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
828   const int output_dim_count = output_shape.DimensionsCount();
829   TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
830   const int input_size = FlatSizeSkipDim(input_shape, 0);
831   static constexpr int kPeel = 4;
832   const bool shift_left = (output_shift > 0);
833   for (int k = 0; k < input_size; k += 64) {
834     optimized_ops_preload_l1_stream(input_data + k);
835   }
836   for (int k = 0; k < kPeel * input_size; k += 64) {
837     optimized_ops_preload_l1_stream(filter_data + k);
838   }
839 
840   TFLITE_DCHECK_GE(row_end - row_start, kPeel);
841 
842   for (int out = row_start; out < row_end; out += kPeel) {
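    // Descriptive note: if fewer than kPeel rows remain, back 'out' up so the
    // final block stays in range; a few rows may then simply be recomputed.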
843     out = std::min(out, row_end - kPeel);
844     int32x4_t acc0 = vdupq_n_s32(0);
845     int32x4_t acc1 = acc0;
846     int32x4_t acc2 = acc0;
847     int32x4_t acc3 = acc0;
848     const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
849     const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset);
850     int in = 0;
851     for (; in <= input_size - 16; in += 16) {
852       const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
853       const uint8* filter_ptr = filter_data + in + out * input_size;
854       uint8x16_t filter_val_u8_0 = vld1q_u8(filter_ptr);
855       optimized_ops_preload_l1_stream(filter_ptr + 64);
856       filter_ptr += input_size;
857       uint8x16_t filter_val_u8_1 = vld1q_u8(filter_ptr);
858       optimized_ops_preload_l1_stream(filter_ptr + 64);
859       filter_ptr += input_size;
860       uint8x16_t filter_val_u8_2 = vld1q_u8(filter_ptr);
861       optimized_ops_preload_l1_stream(filter_ptr + 64);
862       filter_ptr += input_size;
863       uint8x16_t filter_val_u8_3 = vld1q_u8(filter_ptr);
864       optimized_ops_preload_l1_stream(filter_ptr + 64);
865       int16x8_t input_val_0, input_val_1;
866       uint8x8_t low = vget_low_u8(input_val_u8);
867       uint8x8_t high = vget_high_u8(input_val_u8);
868       input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low));
869       input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high));
870       input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
871       input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
872       low = vget_low_u8(filter_val_u8_0);
873       high = vget_high_u8(filter_val_u8_0);
874       int16x8_t filter_val_0_0 = vreinterpretq_s16_u16(vmovl_u8(low));
875       int16x8_t filter_val_0_1 = vreinterpretq_s16_u16(vmovl_u8(high));
876       filter_val_0_0 = vaddq_s16(filter_val_0_0, filter_offset_vec);
877       filter_val_0_1 = vaddq_s16(filter_val_0_1, filter_offset_vec);
878       low = vget_low_u8(filter_val_u8_1);
879       high = vget_high_u8(filter_val_u8_1);
880       int16x8_t filter_val_1_0 = vreinterpretq_s16_u16(vmovl_u8(low));
881       int16x8_t filter_val_1_1 = vreinterpretq_s16_u16(vmovl_u8(high));
882       filter_val_1_0 = vaddq_s16(filter_val_1_0, filter_offset_vec);
883       filter_val_1_1 = vaddq_s16(filter_val_1_1, filter_offset_vec);
884       low = vget_low_u8(filter_val_u8_2);
885       high = vget_high_u8(filter_val_u8_2);
886       int16x8_t filter_val_2_0 = vreinterpretq_s16_u16(vmovl_u8(low));
887       int16x8_t filter_val_2_1 = vreinterpretq_s16_u16(vmovl_u8(high));
888       filter_val_2_0 = vaddq_s16(filter_val_2_0, filter_offset_vec);
889       filter_val_2_1 = vaddq_s16(filter_val_2_1, filter_offset_vec);
890       low = vget_low_u8(filter_val_u8_3);
891       high = vget_high_u8(filter_val_u8_3);
892       int16x8_t filter_val_3_0 = vreinterpretq_s16_u16(vmovl_u8(low));
893       int16x8_t filter_val_3_1 = vreinterpretq_s16_u16(vmovl_u8(high));
894       filter_val_3_0 = vaddq_s16(filter_val_3_0, filter_offset_vec);
895       filter_val_3_1 = vaddq_s16(filter_val_3_1, filter_offset_vec);
896       acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_0),
897                        vget_low_s16(input_val_0));
898       acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_0),
899                        vget_low_s16(input_val_0));
900       acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_0),
901                        vget_low_s16(input_val_0));
902       acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_0),
903                        vget_low_s16(input_val_0));
904       acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_1),
905                        vget_low_s16(input_val_1));
906       acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_1),
907                        vget_low_s16(input_val_1));
908       acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_1),
909                        vget_low_s16(input_val_1));
910       acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_1),
911                        vget_low_s16(input_val_1));
912       acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_0),
913                        vget_high_s16(input_val_0));
914       acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_0),
915                        vget_high_s16(input_val_0));
916       acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_0),
917                        vget_high_s16(input_val_0));
918       acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_0),
919                        vget_high_s16(input_val_0));
920       acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_1),
921                        vget_high_s16(input_val_1));
922       acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_1),
923                        vget_high_s16(input_val_1));
924       acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_1),
925                        vget_high_s16(input_val_1));
926       acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_1),
927                        vget_high_s16(input_val_1));
928     }
929     for (; in <= input_size - 8; in += 8) {
930       const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
931       const uint8* filter_ptr = filter_data + in + out * input_size;
932       uint8x8_t filter_val_u8_0 = vld1_u8(filter_ptr);
933       filter_ptr += input_size;
934       uint8x8_t filter_val_u8_1 = vld1_u8(filter_ptr);
935       filter_ptr += input_size;
936       uint8x8_t filter_val_u8_2 = vld1_u8(filter_ptr);
937       filter_ptr += input_size;
938       uint8x8_t filter_val_u8_3 = vld1_u8(filter_ptr);
939       int16x8_t input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
940       input_val = vaddq_s16(input_val, input_offset_vec);
941       int16x8_t filter_val_0 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_0));
942       filter_val_0 = vaddq_s16(filter_val_0, filter_offset_vec);
943       int16x8_t filter_val_1 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_1));
944       filter_val_1 = vaddq_s16(filter_val_1, filter_offset_vec);
945       int16x8_t filter_val_2 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_2));
946       filter_val_2 = vaddq_s16(filter_val_2, filter_offset_vec);
947       int16x8_t filter_val_3 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_3));
948       filter_val_3 = vaddq_s16(filter_val_3, filter_offset_vec);
949       acc0 =
950           vmlal_s16(acc0, vget_low_s16(filter_val_0), vget_low_s16(input_val));
951       acc1 =
952           vmlal_s16(acc1, vget_low_s16(filter_val_1), vget_low_s16(input_val));
953       acc2 =
954           vmlal_s16(acc2, vget_low_s16(filter_val_2), vget_low_s16(input_val));
955       acc3 =
956           vmlal_s16(acc3, vget_low_s16(filter_val_3), vget_low_s16(input_val));
957       acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
958                        vget_high_s16(input_val));
959       acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
960                        vget_high_s16(input_val));
961       acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
962                        vget_high_s16(input_val));
963       acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
964                        vget_high_s16(input_val));
965     }
966     if (in < input_size) {
967       int32 buf[16];
968       vst1q_s32(buf + 0, acc0);
969       vst1q_s32(buf + 4, acc1);
970       vst1q_s32(buf + 8, acc2);
971       vst1q_s32(buf + 12, acc3);
972       for (; in < input_size; in++) {
973         int lane = (in + 8 - input_size) % 4;
974         const int32 input_val = input_data[in] + input_offset;
975         for (int k = 0; k < kPeel; k++) {
976           int32 filter_val =
977               filter_data[in + (out + k) * input_size] + filter_offset;
978           buf[lane + 4 * k] += filter_val * input_val;
979         }
980       }
981       acc0 = vld1q_s32(buf + 0);
982       acc1 = vld1q_s32(buf + 4);
983       acc2 = vld1q_s32(buf + 8);
984       acc3 = vld1q_s32(buf + 12);
985     }
986 
987     // Horizontally reduce accumulators
988     int32x2_t pairwise_reduced_acc_0 =
989         vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0));
990     int32x2_t pairwise_reduced_acc_1 =
991         vpadd_s32(vget_low_s32(acc1), vget_high_s32(acc1));
992     int32x2_t pairwise_reduced_acc_2 =
993         vpadd_s32(vget_low_s32(acc2), vget_high_s32(acc2));
994     int32x2_t pairwise_reduced_acc_3 =
995         vpadd_s32(vget_low_s32(acc3), vget_high_s32(acc3));
996     const int32x2_t reduced_lo =
997         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
998     const int32x2_t reduced_hi =
999         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
1000     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
1001     // Add bias values.
1002     int32x4_t bias_vec = vld1q_s32(bias_data + out);
1003     reduced = vaddq_s32(reduced, bias_vec);
1004     if (shift_left) {
1005       const int32 multiplier_power_of_two = 1 << output_shift;
1006       reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
1007       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
1008     } else {
1009       // Multiply by the fixed-point multiplier.
1010       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
1011       // Rounding-shift-right.
1012       using gemmlowp::RoundingDivideByPOT;
1013       reduced = RoundingDivideByPOT(reduced, -output_shift);
1014     }
1015     // Add the output offset.
1016     const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
1017     reduced = vaddq_s32(reduced, output_offset_vec);
1018     // Narrow values down to 16 bit signed.
1019     const int16x4_t res16 = vqmovn_s32(reduced);
1020     // Narrow values down to 8 bit unsigned, saturating.
1021     uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16));
1022     // Apply the clamping from the activation function
1023     res8 = vmax_u8(res8, vdup_n_u8(output_activation_min));
1024     res8 = vmin_u8(res8, vdup_n_u8(output_activation_max));
1025     // Store results to destination.
1026     vst1_lane_u8(output_data + out + 0, res8, 0);
1027     vst1_lane_u8(output_data + out + 1, res8, 1);
1028     vst1_lane_u8(output_data + out + 2, res8, 2);
1029     vst1_lane_u8(output_data + out + 3, res8, 3);
1030   }
1031 }
1032 
1033 struct FullyConnectedAsGEMVWorkerTask : public gemmlowp::Task {
1034   FullyConnectedAsGEMVWorkerTask(const RuntimeShape& input_shape,
1035                                  const uint8* input_data, int32 input_offset,
1036                                  const RuntimeShape& filter_shape,
1037                                  const uint8* filter_data, int32 filter_offset,
1038                                  const RuntimeShape& bias_shape,
1039                                  const int32* bias_data, int32 output_offset,
1040                                  int32 output_multiplier, int output_shift,
1041                                  int32 output_activation_min,
1042                                  int32 output_activation_max,
1043                                  const RuntimeShape& output_shape,
1044                                  uint8* output_data, int row_start, int row_end)
1045       : input_shape_(input_shape),
1046         input_data_(input_data),
1047         input_offset_(input_offset),
1048         filter_shape_(filter_shape),
1049         filter_data_(filter_data),
1050         filter_offset_(filter_offset),
1051         bias_shape_(bias_shape),
1052         bias_data_(bias_data),
1053         output_offset_(output_offset),
1054         output_multiplier_(output_multiplier),
1055         output_shift_(output_shift),
1056         output_activation_min_(output_activation_min),
1057         output_activation_max_(output_activation_max),
1058         output_shape_(output_shape),
1059         output_data_(output_data),
1060         row_start_(row_start),
1061         row_end_(row_end) {}
1062 
1063   void Run() override {
1064     FullyConnectedAsGEMVWorkerImpl(
1065         input_shape_, input_data_, input_offset_, filter_shape_, filter_data_,
1066         filter_offset_, bias_shape_, bias_data_, output_offset_,
1067         output_multiplier_, output_shift_, output_activation_min_,
1068         output_activation_max_, output_shape_, output_data_, row_start_,
1069         row_end_);
1070   }
1071 
1072   const RuntimeShape& input_shape_;
1073   const uint8* input_data_;
1074   int32 input_offset_;
1075   const RuntimeShape& filter_shape_;
1076   const uint8* filter_data_;
1077   int32 filter_offset_;
1078   const RuntimeShape& bias_shape_;
1079   const int32* bias_data_;
1080   int32 output_offset_;
1081   int32 output_multiplier_;
1082   int output_shift_;
1083   int32 output_activation_min_;
1084   int32 output_activation_max_;
1085   const RuntimeShape& output_shape_;
1086   uint8* output_data_;
1087   gemmlowp::GemmContext* gemm_context_;  // Unused: not set by the constructor.
1088   int row_start_;
1089   int row_end_;
1090 };
1091 
1092 inline void FullyConnectedAsGEMV(
1093     const RuntimeShape& input_shape, const uint8* input_data,
1094     int32 input_offset, const RuntimeShape& filter_shape,
1095     const uint8* filter_data, int32 filter_offset,
1096     const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
1097     int32 output_multiplier, int output_shift, int32 output_activation_min,
1098     int32 output_activation_max, const RuntimeShape& output_shape,
1099     uint8* output_data, gemmlowp::GemmContext* gemm_context) {
1100   const int output_dim_count = output_shape.DimensionsCount();
1101   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
1102   const int output_rows = output_shape.Dims(output_dim_count - 1);
1103   const int input_size = FlatSizeSkipDim(input_shape, 0);
1104   static constexpr int kKernelRows = 4;
1105   const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
1106       gemm_context->max_num_threads(), output_rows, batches, input_size);
1107   if (thread_count == 1) {
1108     // Single-thread case: do the computation on the current thread
1109     // without using a threadpool.
1110     FullyConnectedAsGEMVWorkerImpl(
1111         input_shape, input_data, input_offset, filter_shape, filter_data,
1112         filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
1113         output_shift, output_activation_min, output_activation_max,
1114         output_shape, output_data, 0, output_rows);
1115     return;
1116   }
1117 
1118   // Multi-threaded case: use the gemmlowp context's threadpool.
1119   TFLITE_DCHECK_GT(thread_count, 1);
1120   std::vector<gemmlowp::Task*> tasks(thread_count);
1121   const int kRowsPerWorker =
1122       gemmlowp::RoundUp<kKernelRows>(output_rows / thread_count);
1123   int row_start = 0;
1124   for (int i = 0; i < thread_count; ++i) {
1125     int row_end = std::min(output_rows, row_start + kRowsPerWorker);
1126     tasks[i] = new FullyConnectedAsGEMVWorkerTask(
1127         input_shape, input_data, input_offset, filter_shape, filter_data,
1128         filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
1129         output_shift, output_activation_min, output_activation_max,
1130         output_shape, output_data, row_start, row_end);
1131     row_start = row_end;
1132   }
1133   TFLITE_DCHECK_EQ(row_start, output_rows);
1134   gemm_context->workers_pool()->Execute(tasks);
1135 }
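// Worked example of the row partitioning above (illustrative numbers only):
// with output_rows = 10 and thread_count = 2, kRowsPerWorker =
// RoundUp<4>(10 / 2) = 8, so worker 0 handles rows [0, 8) and worker 1
// handles rows [8, 10); row_start then equals output_rows, which is what the
// TFLITE_DCHECK_EQ above verifies.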
1136 #endif  // USE_NEON
1137 
1138 struct GemmlowpOutputPipeline {
1139   typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
1140       ColVectorMap;
1141   typedef std::tuple<gemmlowp::OutputStageBiasAddition<ColVectorMap>,
1142                      gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent,
1143                      gemmlowp::OutputStageClamp,
1144                      gemmlowp::OutputStageSaturatingCastToUint8>
1145       Pipeline;
1146   static Pipeline MakeExp(const int32* bias_data, int output_rows,
1147                           int32 output_offset, int32 output_multiplier,
1148                           int output_left_shift, int32 output_activation_min,
1149                           int32 output_activation_max) {
1150     ColVectorMap bias_vector(bias_data, output_rows);
1151     gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
1152     bias_addition_stage.bias_vector = bias_vector;
1153     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
1154     quantize_down_stage.result_offset_after_shift = output_offset;
1155     quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
1156     quantize_down_stage.result_exponent = output_left_shift;
1157     gemmlowp::OutputStageClamp clamp_stage;
1158     clamp_stage.min = output_activation_min;
1159     clamp_stage.max = output_activation_max;
1160     gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
1161     return std::make_tuple(bias_addition_stage, quantize_down_stage,
1162                            clamp_stage, saturating_cast_stage);
1163   }
1164 };
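// Pseudocode sketch of the pipeline built by MakeExp above (the helper name
// ScaleByFixedPointAndExponent is shorthand for the gemmlowp stage, not a
// real function): for each accumulator acc of a given output row,
//   acc = ScaleByFixedPointAndExponent(acc + bias[row]) + output_offset;
//   acc = Clamp(acc, output_activation_min, output_activation_max);
//   out = SaturatingCast<uint8>(acc);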
1165 
1166 inline void FullyConnected(
1167     const FullyConnectedParams& params, const RuntimeShape& input_shape,
1168     const uint8* input_data, const RuntimeShape& filter_shape,
1169     const uint8* filter_data, const RuntimeShape& bias_shape,
1170     const int32* bias_data, const RuntimeShape& output_shape,
1171     uint8* output_data, gemmlowp::GemmContext* gemm_context) {
1172   gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit");
1173   const int32 input_offset = params.input_offset;
1174   const int32 filter_offset = params.weights_offset;
1175   const int32 output_offset = params.output_offset;
1176   const int32 output_multiplier = params.output_multiplier;
1177   const int output_shift = params.output_shift;
1178   const int32 output_activation_min = params.quantized_activation_min;
1179   const int32 output_activation_max = params.quantized_activation_max;
1180   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
1181   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1182   // TODO(benoitjacob): This really should be:
1183   //     const int batches = ArraySize(output_dims, 1);
1184   // but the current --variable_batch hack consists in overwriting the 3rd
1185   // dimension with the runtime batch size, as we don't keep track for each
1186   // array of which dimension is the batch dimension in it.
1187   const int output_dim_count = output_shape.DimensionsCount();
1188   const int filter_dim_count = filter_shape.DimensionsCount();
1189   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
1190 #ifdef USE_NEON
1191   if (batches == 1) {
1192     const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
1193                                         output_shape, output_dim_count - 1);
1194     if (output_size >= 4) {
1195       return FullyConnectedAsGEMV(
1196           input_shape, input_data, input_offset, filter_shape, filter_data,
1197           filter_offset, bias_shape, bias_data, output_offset,
1198           output_multiplier, output_shift, output_activation_min,
1199           output_activation_max, output_shape, output_data, gemm_context);
1200     }
1201   }
1202 #endif  // USE_NEON
1203   const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
1204   const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
1205   TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
1206   const int output_rows = output_shape.Dims(output_dim_count - 1);
1207   TFLITE_DCHECK_EQ(output_rows, filter_rows);
1208   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1209 
1210   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
1211       filter_data, output_rows, filter_cols, filter_cols);
1212   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
1213       input_data, filter_cols, batches, filter_cols);
1214   gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
1215       output_data, output_rows, batches, output_rows);
1216   const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
1217       bias_data, output_rows, output_offset, output_multiplier, output_shift,
1218       output_activation_min, output_activation_max);
1219   gemmlowp::GemmWithOutputPipeline<uint8, uint8,
1220                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
1221       gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
1222       input_offset, output_pipeline);
1223 }
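// Shape sketch for the GEMM above: filter_matrix is (output_rows x
// filter_cols) row-major, input_matrix is (filter_cols x batches)
// column-major and output_matrix is (output_rows x batches) column-major, so
// the call computes output[r][b] = sum_c filter[r][c] * input[c][b], with the
// per-operand zero-point offsets and the output pipeline handled by gemmlowp.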
1224 
1225 inline void FullyConnected(
1226     const FullyConnectedParams& params, const RuntimeShape& input_shape,
1227     const uint8* input_data, const RuntimeShape& filter_shape,
1228     const uint8* filter_data, const RuntimeShape& bias_shape,
1229     const int32* bias_data_int32, const RuntimeShape& output_shape,
1230     int16* output_data, gemmlowp::GemmContext* gemm_context) {
1231   gemmlowp::ScopedProfilingLabel label("FullyConnected/Uint8Int16");
1232   const int32 input_offset = params.input_offset;
1233   const int32 filter_offset = params.weights_offset;
1234   const int32 output_offset = params.output_offset;
1235   const int32 output_multiplier = params.output_multiplier;
1236   const int output_shift = params.output_shift;
1237   const int32 output_activation_min = params.quantized_activation_min;
1238   const int32 output_activation_max = params.quantized_activation_max;
1239   // This is a copy of the reference implementation. We do not currently have a
1240   // properly optimized version.
1241   (void)gemm_context;  // only used in properly optimized code.
1242   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1243   TFLITE_DCHECK_EQ(output_offset, 0);
1244   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
1245   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1246 
1247   // TODO(benoitjacob): This really should be:
1248   //     const int batches = ArraySize(output_dims, 1);
1249   // but the current --variable_batch hack consists in overwriting the 3rd
1250   // dimension with the runtime batch size, as we don't keep track for each
1251   // array of which dimension is the batch dimension in it.
1252   const int output_dim_count = output_shape.DimensionsCount();
1253   const int filter_dim_count = filter_shape.DimensionsCount();
1254   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
1255   const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
1256                                        output_shape, output_dim_count - 1);
1257   const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
1258 
1259   // Implementation of the fully connected node suited to the inside of an LSTM
1260   // cell. The operands are 8-bit integers, the accumulators are internally
1261   // 32-bit integers, and the output is 16-bit fixed-point with 3 integer bits so
1262   // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
1263   // is explained in the function comment above.
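  // For reference: with 3 integer bits the int16 output is in Q3.12 format,
  // so a stored value of 4096 represents 1.0 and the resolution is
  // 2^-12 (about 0.000244).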
1264 #ifdef GEMMLOWP_NEON
1265   if (batches == 1 && input_offset == -128 && output_activation_min == -32768 &&
1266       output_activation_max == 32767) {
1267     if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
1268       GEMVForLstmCellWithSymmetricRange(
1269           input_shape, input_data, filter_shape, filter_data, bias_shape,
1270           bias_data_int32, output_multiplier, output_shift, output_shape,
1271           output_data);
1272       return;
1273     }
1274     if (!(output_depth % 4) && !(accum_depth % 8)) {
1275       GEMVForLstmCell(input_shape, input_data, filter_shape, filter_data,
1276                       filter_offset, bias_shape, bias_data_int32,
1277                       output_multiplier, output_shift, output_shape,
1278                       output_data);
1279       return;
1280     }
1281   }
1282 #endif
1283   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> weights_matrix(
1284       filter_data, output_depth, accum_depth);
1285   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
1286       input_data, accum_depth, batches);
1287   gemmlowp::MatrixMap<int16, gemmlowp::MapOrder::ColMajor> output_matrix(
1288       output_data, output_depth, batches);
1289   typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
1290       ColVectorMap;
1291   ColVectorMap bias_vector(bias_data_int32, output_depth);
1292   gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
1293   bias_addition_stage.bias_vector = bias_vector;
1294   gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
1295   scale_stage.result_offset_after_shift = 0;
1296   scale_stage.result_fixedpoint_multiplier = output_multiplier;
1297   // Note that this shift is negated wrt ordinary FC.
1298   scale_stage.result_exponent = output_shift;
1299   gemmlowp::OutputStageClamp clamp_stage;
1300   clamp_stage.min = output_activation_min;
1301   clamp_stage.max = output_activation_max;
1302   gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
1303   auto output_pipeline =
1304       std::make_tuple(bias_addition_stage, scale_stage, clamp_stage,
1305                       saturating_cast_int16_stage);
1306   gemmlowp::GemmWithOutputPipeline<uint8, int16,
1307                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
1308       gemm_context, weights_matrix, input_matrix, &output_matrix, filter_offset,
1309       input_offset, output_pipeline);
1310 }
1311 
1312 // Internal function doing the actual arithmetic work for
1313 // ShuffledFullyConnected.
1314 // May be called either directly by it (single-threaded case) or may be used
1315 // as the 'task' for worker threads to run (multi-threaded case, see
1316 // ShuffledFullyConnectedWorkerTask below).
1317 inline void ShuffledFullyConnectedWorkerImpl(
1318     const uint8* shuffled_input_workspace_data,
1319     const int8* shuffled_weights_data, int batches, int output_depth,
1320     int output_stride, int accum_depth, const int32* bias_data,
1321     int32 output_multiplier, int output_shift, int16* output_data) {
1322 #if defined USE_NEON
1323   const int8* shuffled_weights_ptr = shuffled_weights_data;
1324   if (batches == 1) {
1325     const int right_shift = output_shift > 0 ? 0 : -output_shift;
1326     const int left_shift = output_shift > 0 ? output_shift : 0;
1327     for (int c = 0; c < output_depth; c += 4) {
1328       // Accumulation loop.
1329       int32x4_t row_accum0 = vdupq_n_s32(0);
1330       int32x4_t row_accum1 = vdupq_n_s32(0);
1331       int32x4_t row_accum2 = vdupq_n_s32(0);
1332       int32x4_t row_accum3 = vdupq_n_s32(0);
1333       for (int d = 0; d < accum_depth; d += 16) {
1334         int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
1335         int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
1336         int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
1337         int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
1338         shuffled_weights_ptr += 64;
1339         int8x16_t input =
1340             vreinterpretq_s8_u8(vld1q_u8(shuffled_input_workspace_data + d));
1341         int16x8_t local_accum0 =
1342             vmull_s8(vget_low_s8(weights0), vget_low_s8(input));
1343         int16x8_t local_accum1 =
1344             vmull_s8(vget_low_s8(weights1), vget_low_s8(input));
1345         int16x8_t local_accum2 =
1346             vmull_s8(vget_low_s8(weights2), vget_low_s8(input));
1347         int16x8_t local_accum3 =
1348             vmull_s8(vget_low_s8(weights3), vget_low_s8(input));
1349         local_accum0 =
1350             vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input));
1351         local_accum1 =
1352             vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input));
1353         local_accum2 =
1354             vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input));
1355         local_accum3 =
1356             vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input));
1357         row_accum0 = vpadalq_s16(row_accum0, local_accum0);
1358         row_accum1 = vpadalq_s16(row_accum1, local_accum1);
1359         row_accum2 = vpadalq_s16(row_accum2, local_accum2);
1360         row_accum3 = vpadalq_s16(row_accum3, local_accum3);
1361       }
1362       // Horizontally reduce accumulators
1363       int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
1364           pairwise_reduced_acc_2, pairwise_reduced_acc_3;
1365       pairwise_reduced_acc_0 =
1366           vpadd_s32(vget_low_s32(row_accum0), vget_high_s32(row_accum0));
1367       pairwise_reduced_acc_1 =
1368           vpadd_s32(vget_low_s32(row_accum1), vget_high_s32(row_accum1));
1369       pairwise_reduced_acc_2 =
1370           vpadd_s32(vget_low_s32(row_accum2), vget_high_s32(row_accum2));
1371       pairwise_reduced_acc_3 =
1372           vpadd_s32(vget_low_s32(row_accum3), vget_high_s32(row_accum3));
1373       const int32x2_t reduced_lo =
1374           vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
1375       const int32x2_t reduced_hi =
1376           vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
1377       int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
1378       // Add bias values.
1379       int32x4_t bias_vec = vld1q_s32(bias_data + c);
1380       reduced = vaddq_s32(reduced, bias_vec);
1381       reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
1382       // Multiply by the fixed-point multiplier.
1383       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
1384       // Rounding-shift-right.
1385       using gemmlowp::RoundingDivideByPOT;
1386       reduced = RoundingDivideByPOT(reduced, right_shift);
1387       // Narrow values down to 16 bit signed.
1388       const int16x4_t res16 = vqmovn_s32(reduced);
1389       vst1_s16(output_data + c, res16);
1390     }
1391   } else if (batches == 4) {
1392     const int right_shift = output_shift > 0 ? 0 : -output_shift;
1393     const int left_shift = output_shift > 0 ? output_shift : 0;
1394     for (int c = 0; c < output_depth; c += 4) {
1395       const int8* shuffled_input_ptr =
1396           reinterpret_cast<const int8*>(shuffled_input_workspace_data);
1397       // Accumulation loop.
1398       int32x4_t row_accum00 = vdupq_n_s32(0);
1399       int32x4_t row_accum10 = vdupq_n_s32(0);
1400       int32x4_t row_accum20 = vdupq_n_s32(0);
1401       int32x4_t row_accum30 = vdupq_n_s32(0);
1402       int32x4_t row_accum01 = vdupq_n_s32(0);
1403       int32x4_t row_accum11 = vdupq_n_s32(0);
1404       int32x4_t row_accum21 = vdupq_n_s32(0);
1405       int32x4_t row_accum31 = vdupq_n_s32(0);
1406       int32x4_t row_accum02 = vdupq_n_s32(0);
1407       int32x4_t row_accum12 = vdupq_n_s32(0);
1408       int32x4_t row_accum22 = vdupq_n_s32(0);
1409       int32x4_t row_accum32 = vdupq_n_s32(0);
1410       int32x4_t row_accum03 = vdupq_n_s32(0);
1411       int32x4_t row_accum13 = vdupq_n_s32(0);
1412       int32x4_t row_accum23 = vdupq_n_s32(0);
1413       int32x4_t row_accum33 = vdupq_n_s32(0);
1414       for (int d = 0; d < accum_depth; d += 16) {
1415         int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
1416         int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
1417         int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
1418         int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
1419         shuffled_weights_ptr += 64;
1420         int8x16_t input0 = vld1q_s8(shuffled_input_ptr + 0);
1421         int8x16_t input1 = vld1q_s8(shuffled_input_ptr + 16);
1422         int8x16_t input2 = vld1q_s8(shuffled_input_ptr + 32);
1423         int8x16_t input3 = vld1q_s8(shuffled_input_ptr + 48);
1424         shuffled_input_ptr += 64;
1425         int16x8_t local_accum0, local_accum1, local_accum2, local_accum3;
1426 #define TFLITE_SHUFFLED_FC_ACCUM(B)                                           \
1427   local_accum0 = vmull_s8(vget_low_s8(weights0), vget_low_s8(input##B));      \
1428   local_accum1 = vmull_s8(vget_low_s8(weights1), vget_low_s8(input##B));      \
1429   local_accum2 = vmull_s8(vget_low_s8(weights2), vget_low_s8(input##B));      \
1430   local_accum3 = vmull_s8(vget_low_s8(weights3), vget_low_s8(input##B));      \
1431   local_accum0 =                                                              \
1432       vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input##B)); \
1433   local_accum1 =                                                              \
1434       vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input##B)); \
1435   local_accum2 =                                                              \
1436       vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input##B)); \
1437   local_accum3 =                                                              \
1438       vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input##B)); \
1439   row_accum0##B = vpadalq_s16(row_accum0##B, local_accum0);                   \
1440   row_accum1##B = vpadalq_s16(row_accum1##B, local_accum1);                   \
1441   row_accum2##B = vpadalq_s16(row_accum2##B, local_accum2);                   \
1442   row_accum3##B = vpadalq_s16(row_accum3##B, local_accum3);
1443 
1444         TFLITE_SHUFFLED_FC_ACCUM(0)
1445         TFLITE_SHUFFLED_FC_ACCUM(1)
1446         TFLITE_SHUFFLED_FC_ACCUM(2)
1447         TFLITE_SHUFFLED_FC_ACCUM(3)
1448 
1449 #undef TFLITE_SHUFFLED_FC_ACCUM
1450       }
1451       // Horizontally reduce accumulators
1452 
1453 #define TFLITE_SHUFFLED_FC_STORE(B)                                           \
1454   {                                                                           \
1455     int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,                 \
1456         pairwise_reduced_acc_2, pairwise_reduced_acc_3;                       \
1457     pairwise_reduced_acc_0 =                                                  \
1458         vpadd_s32(vget_low_s32(row_accum0##B), vget_high_s32(row_accum0##B)); \
1459     pairwise_reduced_acc_1 =                                                  \
1460         vpadd_s32(vget_low_s32(row_accum1##B), vget_high_s32(row_accum1##B)); \
1461     pairwise_reduced_acc_2 =                                                  \
1462         vpadd_s32(vget_low_s32(row_accum2##B), vget_high_s32(row_accum2##B)); \
1463     pairwise_reduced_acc_3 =                                                  \
1464         vpadd_s32(vget_low_s32(row_accum3##B), vget_high_s32(row_accum3##B)); \
1465     const int32x2_t reduced_lo =                                              \
1466         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);            \
1467     const int32x2_t reduced_hi =                                              \
1468         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);            \
1469     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);                 \
1470     int32x4_t bias_vec = vld1q_s32(bias_data + c);                            \
1471     reduced = vaddq_s32(reduced, bias_vec);                                   \
1472     reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));                    \
1473     reduced = vqrdmulhq_n_s32(reduced, output_multiplier);                    \
1474     using gemmlowp::RoundingDivideByPOT;                                      \
1475     reduced = RoundingDivideByPOT(reduced, right_shift);                      \
1476     const int16x4_t res16 = vqmovn_s32(reduced);                              \
1477     vst1_s16(output_data + c + B * output_stride, res16);                     \
1478   }
1479 
1480       TFLITE_SHUFFLED_FC_STORE(0);
1481       TFLITE_SHUFFLED_FC_STORE(1);
1482       TFLITE_SHUFFLED_FC_STORE(2);
1483       TFLITE_SHUFFLED_FC_STORE(3);
1484 
1485 #undef TFLITE_SHUFFLED_FC_STORE
1486     }
1487   } else {
1488     TFLITE_DCHECK(false);
1489     return;
1490   }
1491 #else
1492   if (batches == 1) {
1493     int16* output_ptr = output_data;
1494     // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
1495     // so that just reinterpreting them as int8 values is equivalent to
1496     // subtracting 128 from them, thus implementing for free the subtraction of
1497     // the zero_point value 128.
1498     const int8* shuffled_weights_ptr =
1499         reinterpret_cast<const int8*>(shuffled_weights_data);
1500     // Likewise, we preshuffled and pre-xored the input data above.
1501     const int8* shuffled_input_data =
1502         reinterpret_cast<const int8*>(shuffled_input_workspace_data);
1503     for (int c = 0; c < output_depth; c += 4) {
1504       // Internal accumulation.
1505       // Accumulators start at zero; the bias is added after the loop.
1506       int32 accum[4] = {0};
1507       // Accumulation loop.
1508       for (int d = 0; d < accum_depth; d += 16) {
1509         for (int i = 0; i < 4; i++) {
1510           for (int j = 0; j < 16; j++) {
1511             int8 input_val = shuffled_input_data[d + j];
1512             int8 weights_val = *shuffled_weights_ptr++;
1513             accum[i] += weights_val * input_val;
1514           }
1515         }
1516       }
1517       for (int i = 0; i < 4; i++) {
1518         // Add bias value
1519         int acc = accum[i] + bias_data[c + i];
1520         // Down-scale the final int32 accumulator to the scale used by our
1521         // (16-bit, typically 3 integer bits) fixed-point format. The quantized
1522         // multiplier and shift here have been pre-computed offline
1523         // (e.g. by toco).
1524         acc =
1525             MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
1526         // Saturate, cast to int16, and store to output array.
1527         acc = std::max(acc, -32768);
1528         acc = std::min(acc, 32767);
1529         output_ptr[c + i] = acc;
1530       }
1531     }
1532   } else if (batches == 4) {
1533     int16* output_ptr = output_data;
1534     // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
1535     // so that just reinterpreting them as int8 values is equivalent to
1536     // subtracting 128 from them, thus implementing for free the subtraction of
1537     // the zero_point value 128.
1538     const int8* shuffled_weights_ptr =
1539         reinterpret_cast<const int8*>(shuffled_weights_data);
1540     // Likewise, we preshuffled and pre-xored the input data above.
1541     const int8* shuffled_input_data =
1542         reinterpret_cast<const int8*>(shuffled_input_workspace_data);
1543     for (int c = 0; c < output_depth; c += 4) {
1544       const int8* shuffled_input_ptr = shuffled_input_data;
1545       // Internal accumulation.
1546       // Accumulators start at zero; the bias is added after the
1547       // accumulation loop.
1548       int32 accum[4][4];
1549       for (int i = 0; i < 4; i++) {
1550         for (int b = 0; b < 4; b++) {
1551           accum[i][b] = 0;
1552         }
1553       }
1554       for (int d = 0; d < accum_depth; d += 16) {
1555         for (int i = 0; i < 4; i++) {
1556           for (int b = 0; b < 4; b++) {
1557             for (int j = 0; j < 16; j++) {
1558               int8 input_val = shuffled_input_ptr[16 * b + j];
1559               int8 weights_val = shuffled_weights_ptr[16 * i + j];
1560               accum[i][b] += weights_val * input_val;
1561             }
1562           }
1563         }
1564         shuffled_input_ptr += 64;
1565         shuffled_weights_ptr += 64;
1566       }
1567       for (int i = 0; i < 4; i++) {
1568         for (int b = 0; b < 4; b++) {
1569           // Add bias value
1570           int acc = accum[i][b] + bias_data[c + i];
1571           // Down-scale the final int32 accumulator to the scale used by our
1572           // (16-bit, typically 3 integer bits) fixed-point format. The
1573           // quantized multiplier and shift here have been pre-computed offline
1574           // (e.g. by toco).
1575           acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
1576                                               output_shift);
1577           // Saturate, cast to int16, and store to output array.
1578           acc = std::max(acc, -32768);
1579           acc = std::min(acc, 32767);
1580           output_ptr[b * output_stride + c + i] = acc;
1581         }
1582       }
1583     }
1584   } else {
1585     TFLITE_DCHECK(false);
1586     return;
1587   }
1588 #endif
1589 }
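// As the accumulation loops above imply, the shuffled weights are stored in
// 64-byte blocks of 4 consecutive output rows x 16 consecutive depth values,
// with the depth blocks laid out consecutively within each group of 4 rows.
// The shuffled input workspace uses the same 16-deep blocking (and, for
// batches == 4, interleaves the four batches 16 bytes at a time).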
1590 
1591 // Wraps ShuffledFullyConnectedWorkerImpl into a Task class
1592 // to allow using gemmlowp's threadpool.
1593 struct ShuffledFullyConnectedWorkerTask : gemmlowp::Task {
1594   ShuffledFullyConnectedWorkerTask(const uint8* input_data,
1595                                    const int8* shuffled_weights_data,
1596                                    int batches, int output_depth,
1597                                    int output_stride, int accum_depth,
1598                                    const int32* bias_data,
1599                                    int32 output_multiplier, int output_shift,
1600                                    int16* output_data)
1601       : input_data_(input_data),
1602         shuffled_weights_data_(shuffled_weights_data),
1603         batches_(batches),
1604         output_depth_(output_depth),
1605         output_stride_(output_stride),
1606         accum_depth_(accum_depth),
1607         bias_data_(bias_data),
1608         output_multiplier_(output_multiplier),
1609         output_shift_(output_shift),
1610         output_data_(output_data) {}
1611 
1612   void Run() override {
1613     ShuffledFullyConnectedWorkerImpl(
1614         input_data_, shuffled_weights_data_, batches_, output_depth_,
1615         output_stride_, accum_depth_, bias_data_, output_multiplier_,
1616         output_shift_, output_data_);
1617   }
1618 
1619   const uint8* input_data_;
1620   const int8* shuffled_weights_data_;
1621   int batches_;
1622   int output_depth_;
1623   int output_stride_;
1624   int accum_depth_;
1625   const int32* bias_data_;
1626   int32 output_multiplier_;
1627   int output_shift_;
1628   int16* output_data_;
1629 };
1630 
1631 inline void ShuffledFullyConnected(
1632     const FullyConnectedParams& params, const RuntimeShape& input_shape,
1633     const uint8* input_data, const RuntimeShape& weights_shape,
1634     const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
1635     const int32* bias_data, const RuntimeShape& output_shape,
1636     int16* output_data, uint8* shuffled_input_workspace_data,
1637     gemmlowp::GemmContext* gemm_context) {
1638   gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit");
1639   const int32 output_multiplier = params.output_multiplier;
1640   const int output_shift = params.output_shift;
1641   const int32 output_activation_min = params.quantized_activation_min;
1642   const int32 output_activation_max = params.quantized_activation_max;
1643   (void)gemm_context;  // only used in optimized code.
1644   TFLITE_DCHECK_EQ(output_activation_min, -32768);
1645   TFLITE_DCHECK_EQ(output_activation_max, 32767);
1646   TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
1647   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
1648   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1649   // TODO(benoitjacob): This really should be:
1650   //     const int batches = ArraySize(output_dims, 1);
1651   // but the current --variable_batch hack consists in overwriting the 3rd
1652   // dimension with the runtime batch size, as we don't keep track for each
1653   // array of which dimension is the batch dimension in it.
1654   const int output_dim_count = output_shape.DimensionsCount();
1655   const int weights_dim_count = weights_shape.DimensionsCount();
1656   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
1657   const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
1658                                        output_shape, output_dim_count - 1);
1659   const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
1660   TFLITE_DCHECK((accum_depth % 16) == 0);
1661   TFLITE_DCHECK((output_depth % 4) == 0);
1662   // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
1663   // so that just reinterpreting them as int8 values is equivalent to
1664   // subtracting 128 from them, thus implementing for free the subtraction of
1665   // the zero_point value 128.
1666   const int8* int8_shuffled_weights_data =
1667       reinterpret_cast<const int8*>(shuffled_weights_data);
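  // Worked example of that trick: the uint8 value 200 encodes the quantized
  // value 200 - 128 = 72; pre-xoring gives 200 ^ 0x80 = 0x48, which
  // reinterpreted as int8 is exactly 72, so no runtime subtraction is needed.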
1668 
1669   // Shuffling and xoring of input activations into the workspace buffer
1670   if (batches == 1) {
1671 #ifdef USE_NEON
1672     const uint8x16_t signbit = vdupq_n_u8(0x80);
1673     for (int i = 0; i < accum_depth; i += 16) {
1674       uint8x16_t val = vld1q_u8(input_data + i);
1675       val = veorq_u8(val, signbit);
1676       vst1q_u8(shuffled_input_workspace_data + i, val);
1677     }
1678 #else
1679     for (int i = 0; i < accum_depth; i++) {
1680       shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
1681     }
1682 #endif
1683   } else if (batches == 4) {
1684     uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
1685     int c = 0;
1686 #ifdef USE_NEON
1687     const uint8x16_t signbit = vdupq_n_u8(0x80);
1688     for (c = 0; c < accum_depth; c += 16) {
1689       const uint8* src_data_ptr = input_data + c;
1690       uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth);
1691       uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth);
1692       uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth);
1693       uint8x16_t val3 = vld1q_u8(src_data_ptr + 3 * accum_depth);
1694       val0 = veorq_u8(val0, signbit);
1695       val1 = veorq_u8(val1, signbit);
1696       val2 = veorq_u8(val2, signbit);
1697       val3 = veorq_u8(val3, signbit);
1698       vst1q_u8(shuffled_input_workspace_ptr + 0, val0);
1699       vst1q_u8(shuffled_input_workspace_ptr + 16, val1);
1700       vst1q_u8(shuffled_input_workspace_ptr + 32, val2);
1701       vst1q_u8(shuffled_input_workspace_ptr + 48, val3);
1702       shuffled_input_workspace_ptr += 64;
1703     }
1704 #else
1705     for (c = 0; c < accum_depth; c += 16) {
1706       for (int b = 0; b < 4; b++) {
1707         const uint8* src_data_ptr = input_data + b * accum_depth + c;
1708         for (int j = 0; j < 16; j++) {
1709           uint8 src_val = *src_data_ptr++;
1710           // Flip the sign bit, so that the kernel will only need to
1711           // reinterpret these uint8 values as int8, getting for free the
1712           // subtraction of the zero_point value 128.
1713           uint8 dst_val = src_val ^ 0x80;
1714           *shuffled_input_workspace_ptr++ = dst_val;
1715         }
1716       }
1717     }
1718 #endif
1719   } else {
1720     TFLITE_DCHECK(false);
1721     return;
1722   }
1723 
1724   static constexpr int kKernelRows = 4;
1725   const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
1726       gemm_context->max_num_threads(), output_depth, batches, accum_depth);
1727   if (thread_count == 1) {
1728     // Single-thread case: do the computation on the current thread
1729     // without using a threadpool.
1730     ShuffledFullyConnectedWorkerImpl(
1731         shuffled_input_workspace_data, int8_shuffled_weights_data, batches,
1732         output_depth, output_depth, accum_depth, bias_data, output_multiplier,
1733         output_shift, output_data);
1734     return;
1735   }
1736 
1737   // Multi-threaded case: use the gemmlowp context's threadpool.
1738   TFLITE_DCHECK_GT(thread_count, 1);
1739   std::vector<gemmlowp::Task*> tasks(thread_count);
1740   const int kRowsPerWorker =
1741       gemmlowp::RoundUp<kKernelRows>(output_depth / thread_count);
1742   int row_start = 0;
1743   for (int i = 0; i < thread_count; i++) {
1744     int row_end = std::min(output_depth, row_start + kRowsPerWorker);
1745     tasks[i] = new ShuffledFullyConnectedWorkerTask(
1746         shuffled_input_workspace_data,
1747         int8_shuffled_weights_data + row_start * accum_depth, batches,
1748         row_end - row_start, output_depth, accum_depth, bias_data + row_start,
1749         output_multiplier, output_shift, output_data + row_start);
1750     row_start = row_end;
1751   }
1752   TFLITE_DCHECK_EQ(row_start, output_depth);
1753   gemm_context->workers_pool()->Execute(tasks);
1754 }
1755 
1756 inline void MeanImpl(const tflite::MeanParams& op_params,
1757                      const RuntimeShape& input_shape, const uint8_t* input_data,
1758                      int32 input_zero_point, float input_scale,
1759                      const RuntimeShape& output_shape, uint8_t* output_data,
1760                      int32 output_zero_point, float output_scale,
1761                      int start_depth, int end_depth) {
1762   gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8/MeanImpl");
1763 
1764   // The current implementation only supports 4-D tensors and simultaneous
1765   // reduction over width and height.
1766   const int output_batch = output_shape.Dims(0);
1767   const int output_height = output_shape.Dims(1);
1768   const int output_width = output_shape.Dims(2);
1769   const int input_height = input_shape.Dims(1);
1770   const int input_width = input_shape.Dims(2);
1771   const float num_elements_in_axis = input_width * input_height;
1772 
1773   TFLITE_DCHECK_EQ(op_params.axis_count, 2);
1774   TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
1775                 (op_params.axis[0] == 2 && op_params.axis[1] == 1));
1776   TFLITE_DCHECK_EQ(output_height, 1);
1777   TFLITE_DCHECK_EQ(output_width, 1);
1778 
1779   const bool ordinary_mean =
1780       (input_zero_point == output_zero_point && input_scale == output_scale);
1781   float scale, bias;
1782   if (!ordinary_mean) {
1783     scale = input_scale / output_scale;
1784     bias = -input_zero_point * scale + 0.5;
1785   }
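  // Rationale for scale/bias: a quantized mean m dequantizes to
  // (m - input_zero_point) * input_scale, and requantizing to the output
  // parameters gives m * (input_scale / output_scale)
  // - input_zero_point * (input_scale / output_scale) + output_zero_point.
  // The first two terms are folded into `scale` and `bias` (with 0.5 added
  // for rounding); output_zero_point is added separately in the loops below.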
1786 
1787 #ifdef USE_NEON
1788   const float32x4_t num_elements_dup = vdupq_n_f32(num_elements_in_axis);
1789   // This is only an approximation as NEON does not offer a division instruction.
1790   const float32x4_t num_elements_reverse = vrecpeq_f32(num_elements_dup);
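  // Note: vrecpeq_f32 provides only about 8 bits of precision; if a more
  // accurate mean were needed, a Newton-Raphson refinement step using
  // vrecpsq_f32 could be inserted here at the cost of two extra instructions.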
1791   const float32x4_t kRounding = vdupq_n_f32(0.5);
1792   float32x4_t bias_dup;
1793   float32x4_t output_zero_point_dup;
1794   if (!ordinary_mean) {
1795     bias_dup = vdupq_n_f32(bias);
1796     output_zero_point_dup = vdupq_n_f32(output_zero_point);
1797   }
1798 #endif
1799 
1800   for (int out_b = 0; out_b < output_batch; ++out_b) {
1801     int out_d = start_depth;
1802 #ifdef USE_NEON
1803 
1804     for (; out_d < end_depth - 8; out_d += 8) {
1805       float32x4_t temp_sum_1 = vdupq_n_f32(0);
1806       float32x4_t temp_sum_2 = vdupq_n_f32(0);
1807       for (int in_h = 0; in_h < input_height; ++in_h) {
1808         for (int in_w = 0; in_w < input_width; ++in_w) {
1809           const uint8_t* input_data_ptr =
1810               input_data + Offset(input_shape, out_b, in_h, in_w, out_d);
1811           uint8x8_t input_data_val = vld1_u8(input_data_ptr);
1812           int16x8_t input_data_val_shift =
1813               vreinterpretq_s16_u16(vmovl_u8(input_data_val));
1814           float32x4_t input_float_1 =
1815               vcvtq_f32_s32(vmovl_s16(vget_high_s16(input_data_val_shift)));
1816           float32x4_t input_float_2 =
1817               vcvtq_f32_s32(vmovl_s16(vget_low_s16(input_data_val_shift)));
1818           temp_sum_1 = vaddq_f32(temp_sum_1, input_float_1);
1819           temp_sum_2 = vaddq_f32(temp_sum_2, input_float_2);
1820         }
1821       }
1822 
1823       float32x4_t mean_1 = vmulq_f32(temp_sum_1, num_elements_reverse);
1824       float32x4_t mean_2 = vmulq_f32(temp_sum_2, num_elements_reverse);
1825 
1826       if (!ordinary_mean) {
1827         // Break mean * scale + bias down into two ops (multiply, then add).
1828         mean_1 = vmulq_n_f32(mean_1, scale);
1829         mean_1 = vaddq_f32(mean_1, bias_dup);
1830         mean_2 = vmulq_n_f32(mean_2, scale);
1831         mean_2 = vaddq_f32(mean_2, bias_dup);
1832       }
1833 
1834       if (!ordinary_mean) {
1835         mean_1 = vaddq_f32(mean_1, output_zero_point_dup);
1836         mean_2 = vaddq_f32(mean_2, output_zero_point_dup);
1837       }
1838 
1839       // Rounding.
1840       mean_1 = vaddq_f32(mean_1, kRounding);
1841       mean_2 = vaddq_f32(mean_2, kRounding);
1842       uint32x4_t casted_mean_1 = vcvtq_u32_f32(mean_1);
1843       uint16x4_t narrow_range_mean_1 = vmovn_u32(casted_mean_1);
1844       uint32x4_t casted_mean_2 = vcvtq_u32_f32(mean_2);
1845       uint16x4_t narrow_range_mean_2 = vmovn_u32(casted_mean_2);
1846       uint16x8_t combined_mean =
1847           vcombine_u16(narrow_range_mean_2, narrow_range_mean_1);
1848       uint8x8_t narrowed_combined_mean = vmovn_u16(combined_mean);
1849       uint8_t* output_data_ptr =
1850           output_data + Offset(output_shape, out_b, 0, 0, out_d);
1851       vst1_u8(output_data_ptr, narrowed_combined_mean);
1852     }
1853 #endif
1854 
1855     for (; out_d < end_depth; ++out_d) {
1856       float temp_value = 0;
1857       for (int in_h = 0; in_h < input_height; ++in_h) {
1858         for (int in_w = 0; in_w < input_width; ++in_w) {
1859           temp_value +=
1860               input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
1861         }
1862       }
1863 
1864       temp_value = temp_value / num_elements_in_axis;
1865       if (ordinary_mean) {
1866         output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
1867             static_cast<uint8_t>(round(temp_value));
1868       } else {
1869         output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
1870             static_cast<uint8_t>(round(temp_value * scale + bias)) +
1871             output_zero_point;
1872       }
1873     }
1874   }
1875 }
1876 
1877 struct MeanWorkerTask : public gemmlowp::Task {
1878   MeanWorkerTask(const tflite::MeanParams& op_params,
1879                  const RuntimeShape& input_shape, const uint8_t* input_data,
1880                  int32 input_zero_point, float input_scale,
1881                  const RuntimeShape& output_shape, uint8_t* output_data,
1882                  int32 output_zero_point, float output_scale, int start_height,
1883                  int end_height)
1884       : op_params_(op_params),
1885         input_shape_(input_shape),
1886         input_data_(input_data),
1887         input_zero_point_(input_zero_point),
1888         input_scale_(input_scale),
1889         output_shape_(output_shape),
1890         output_data_(output_data),
1891         output_zero_point_(output_zero_point),
1892         output_scale_(output_scale),
1893         start_height_(start_height),
1894         end_height_(end_height) {}
1895 
1896   void Run() override {
1897     MeanImpl(op_params_, input_shape_, input_data_, input_zero_point_,
1898              input_scale_, output_shape_, output_data_, output_zero_point_,
1899              output_scale_, start_height_, end_height_);
1900   }
1901 
1902  private:
1903   const tflite::MeanParams& op_params_;
1904   const RuntimeShape& input_shape_;
1905   const uint8_t* input_data_;
1906   int32 input_zero_point_;
1907   float input_scale_;
1908   const RuntimeShape& output_shape_;
1909   uint8_t* output_data_;
1910   int32 output_zero_point_;
1911   float output_scale_;
1912   int start_height_;
1913   int end_height_;
1914   gemmlowp::GemmContext* gemm_context_;  // Unused: not set by the constructor.
1915 };
1916 
1917 inline void Mean(const tflite::MeanParams& op_params,
1918                  const RuntimeShape& unextended_input_shape,
1919                  const uint8_t* input_data, int32 input_zero_point,
1920                  float input_scale, const RuntimeShape& unextended_output_shape,
1921                  uint8_t* output_data, int32 output_zero_point,
1922                  float output_scale, gemmlowp::GemmContext* gemm_context) {
1923   gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8");
1924 
1925   // The current implementation only supports 4-D tensors and simultaneous
1926   // reduction over width and height.
1927   TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4);
1928   TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1929   const RuntimeShape input_shape =
1930       RuntimeShape::ExtendedShape(4, unextended_input_shape);
1931   const RuntimeShape output_shape =
1932       RuntimeShape::ExtendedShape(4, unextended_output_shape);
1933   const int output_height = output_shape.Dims(1);
1934   const int output_width = output_shape.Dims(2);
1935   const int output_depth = output_shape.Dims(3);
1936 
1937   TFLITE_DCHECK_EQ(op_params.axis_count, 2);
1938   TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
1939                 (op_params.axis[0] == 2 && op_params.axis[1] == 1));
1940   TFLITE_DCHECK_EQ(output_height, 1);
1941   TFLITE_DCHECK_EQ(output_width, 1);
1942 
1943   constexpr int kMinDepthPerThread = 8;
1944   int thread_count = output_depth / kMinDepthPerThread;
1945   thread_count = thread_count > 0 ? thread_count : 1;
1946   const int capped_thread_count =
1947       std::min(thread_count, gemm_context->max_num_threads());
1948 
1949   if (thread_count == 1) {
1950     MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale,
1951              output_shape, output_data, output_zero_point, output_scale, 0,
1952              output_depth);
1953   } else {
1954     // Instead of parallelizing over the batch, we loop over output_depth,
1955     // since the batch size is typically 1.
1956     std::vector<gemmlowp::Task*> tasks(capped_thread_count);
1957     int depth_start = 0;
1958     for (int i = 0; i < capped_thread_count; ++i) {
1959       // Try to distribute the tasks as evenly as possible.
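      // For example (illustrative numbers): output_depth = 20 with
      // capped_thread_count = 3 yields chunks of 6, 7 and 7 channels, since
      // each worker takes (remaining depth) / (remaining workers).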
1960       int depth_end = depth_start +
1961                       (output_depth - depth_start) / (capped_thread_count - i);
1962       tasks[i] = new MeanWorkerTask(op_params, input_shape, input_data,
1963                                     input_zero_point, input_scale, output_shape,
1964                                     output_data, output_zero_point,
1965                                     output_scale, depth_start, depth_end);
1966       depth_start = depth_end;
1967     }
1968     gemm_context->workers_pool()->Execute(tasks);
1969   }
1970 }
1971 
1972 template <typename T>
1973 inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
1974                                          int h, int b, int kheight, int kwidth,
1975                                          int stride_width, int stride_height,
1976                                          int pad_width, int pad_height,
1977                                          int in_width, int in_height,
1978                                          int in_depth, int single_buffer_length,
1979                                          int buffer_id, const T* in_data,
1980                                          T* conv_buffer_data, uint8 zero_byte) {
1981   gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn");
1982   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1983   // This chunk of code reshapes all the inputs corresponding to
1984   // output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
1985   const int kwidth_times_indepth = kwidth * in_depth;
1986   const int inwidth_times_indepth = in_width * in_depth;
1987   const int ih_ungated_start = h * stride_height - pad_height;
1988   const int ih_ungated_end = (ih_ungated_start + kheight);
1989   const int ih_end = std::min(ih_ungated_end, in_height);
1990   const int iw_ungated_start = w * stride_width - pad_width;
1991   const int iw_ungated_end = (iw_ungated_start + kwidth);
1992   const int iw_end = std::min(iw_ungated_end, in_width);
1993   // If the patch is off the edge of the input image, skip writing those rows
1994   // and columns from the patch into the output array.
1995   const int h_offset = std::max(0, -ih_ungated_start);
1996   const int w_offset = std::max(0, -iw_ungated_start);
1997   const int ih_start = std::max(0, ih_ungated_start);
1998   const int iw_start = std::max(0, iw_ungated_start);
1999   const int single_row_num =
2000       std::min(kwidth - w_offset, in_width - iw_start) * in_depth;
2001   const int output_row_offset = (buffer_id * single_buffer_length);
2002   int out_offset =
2003       output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
2004   int in_offset = Offset(input_shape, b, ih_start, iw_start, 0);
2005 
2006   // Express all of the calculations as padding around the input patch.
2007   const int top_padding = h_offset;
2008   const int bottom_padding = (ih_ungated_end - ih_end);
2009   const int left_padding = w_offset;
2010   const int right_padding = (iw_ungated_end - iw_end);
2011   assert(single_row_num ==
2012          ((kwidth - (left_padding + right_padding)) * in_depth));
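  // Example (illustrative): with kwidth = 3, pad_width = 1, stride_width = 1
  // and w = 0, iw_ungated_start = -1, so w_offset = left_padding = 1 and one
  // column of zeroes is written in front of each copied input row below.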
2013 
2014   // Write out zeroes to the elements representing the top rows of the input
2015   // patch that are off the edge of the input image.
2016   if (top_padding > 0) {
2017     const int top_row_elements = (top_padding * kwidth * in_depth);
2018     memset(conv_buffer_data + output_row_offset, zero_byte,
2019            (top_row_elements * sizeof(T)));
2020   }
2021 
2022   // If the patch is on the interior of the input image horizontally, just copy
2023   // over the rows sequentially, otherwise add zero padding at the start or end.
2024   if ((left_padding == 0) && (right_padding == 0)) {
2025     for (int ih = ih_start; ih < ih_end; ++ih) {
2026       memcpy(conv_buffer_data + out_offset, in_data + in_offset,
2027              single_row_num * sizeof(T));
2028       out_offset += kwidth_times_indepth;
2029       in_offset += inwidth_times_indepth;
2030     }
2031   } else {
2032     for (int ih = ih_start; ih < ih_end; ++ih) {
2033       if (left_padding > 0) {
2034         const int left_start = (out_offset - (left_padding * in_depth));
2035         memset(conv_buffer_data + left_start, zero_byte,
2036                (left_padding * in_depth * sizeof(T)));
2037       }
2038       memcpy(conv_buffer_data + out_offset, in_data + in_offset,
2039              single_row_num * sizeof(T));
2040       if (right_padding > 0) {
2041         const int right_start = (out_offset + single_row_num);
2042         memset(conv_buffer_data + right_start, zero_byte,
2043                (right_padding * in_depth * sizeof(T)));
2044       }
2045       out_offset += kwidth_times_indepth;
2046       in_offset += inwidth_times_indepth;
2047     }
2048   }
2049 
2050   // If the bottom of the patch falls off the input image, pad the values
2051   // representing those input rows with zeroes.
2052   if (bottom_padding > 0) {
2053     const int bottom_row_elements = (bottom_padding * kwidth * in_depth);
2054     const int bottom_start =
2055         output_row_offset +
2056         ((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
2057     memset(conv_buffer_data + bottom_start, zero_byte,
2058            (bottom_row_elements * sizeof(T)));
2059   }
2060 }
2061 
2062 template <typename T>
2063 void DilatedIm2col(const ConvParams& params, uint8 zero_byte,
2064                    const RuntimeShape& input_shape, const T* input_data,
2065                    const RuntimeShape& filter_shape,
2066                    const RuntimeShape& output_shape, T* im2col_data) {
2067   const int stride_width = params.stride_width;
2068   const int stride_height = params.stride_height;
2069   const int dilation_width_factor = params.dilation_width_factor;
2070   const int dilation_height_factor = params.dilation_height_factor;
2071   const int pad_width = params.padding_values.width;
2072   const int pad_height = params.padding_values.height;
2073   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2074   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2075   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2076 
2077   // For dilated convolution, the input pixels are not contiguous, so we
2078   // can't use the same optimizations as Im2col(). Note that this code would
2079   // also work for the non-dilated case, though likely a bit slower.
2080   gemmlowp::ScopedProfilingLabel label("DilatedIm2col");
2081   TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1);
2082   TFLITE_DCHECK(im2col_data);
2083   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
2084   const int input_height = input_shape.Dims(1);
2085   const int input_width = input_shape.Dims(2);
2086   const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
2087   const int filter_height = filter_shape.Dims(1);
2088   const int filter_width = filter_shape.Dims(2);
2089   const int output_height = output_shape.Dims(1);
2090   const int output_width = output_shape.Dims(2);
2091   MatchingDim(output_shape, 3, filter_shape, 0);
2092 
2093   // Construct the MxN sized im2col matrix.
2094   // The rows, M, are sub-ordered B x H x W
2095   const RuntimeShape row_shape({1, batches, output_height, output_width});
2096   // The columns, N, are sub-ordered Kh x Kw x Din
2097   const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
2098   // Use dimensions M and N to construct dims for indexing directly into im2col
2099   const RuntimeShape im2col_shape(
2100       {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
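  // Illustrative example (hypothetical sizes, not taken from the surrounding
  // code): with batches = 1, a 5 x 5 output, a 3 x 3 filter and input_depth = 8,
  // the im2col matrix is M = 1 * 5 * 5 = 25 rows by N = 3 * 3 * 8 = 72 columns,
  // i.e. one row per output pixel and one column per filter tap and channel.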
2101 
2102   // Loop through the output rows (B x H x W)
2103   for (int batch = 0; batch < batches; ++batch) {
2104     for (int out_y = 0; out_y < output_height; ++out_y) {
2105       for (int out_x = 0; out_x < output_width; ++out_x) {
2106         // Each im2col row is an output pixel. Arrange the input data in this
2107         // row in an order we can conveniently multiply with the filter data.
2108         int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
2109         const int in_x_origin = (out_x * stride_width) - pad_width;
2110         const int in_y_origin = (out_y * stride_height) - pad_height;
2111         // Loop through all the pixels of the filter (Kh x Kw)
2112         for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
2113           const int in_y = in_y_origin + dilation_height_factor * filter_y;
2114           if ((in_y >= 0) && (in_y < input_height)) {
2115             // Filter row is within the input data.
2116             // Loop through all the filter pixels in this row.
2117             for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
2118               const int in_x = in_x_origin + dilation_width_factor * filter_x;
2119               int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
2120               T* dst = im2col_data +
2121                        Offset(im2col_shape, 0, 0, row_offset, col_offset);
2122               if ((in_x >= 0) && (in_x < input_width)) {
2123                 // Filter pixel is within the input, copy the input data.
2124                 T const* src =
2125                     input_data + Offset(input_shape, batch, in_y, in_x, 0);
2126                 memcpy(dst, src, input_depth * sizeof(T));
2127               } else {
2128                 // Filter pixel is outside the input, zero it out.
2129                 memset(dst, zero_byte, input_depth * sizeof(T));
2130               }
2131             }
2132           } else {
2133             // Filter row is outside the input, zero out the entire filter row.
2134             int col_offset = Offset(col_shape, 0, filter_y, 0, 0);
2135             T* dst = im2col_data +
2136                      Offset(im2col_shape, 0, 0, row_offset, col_offset);
2137             memset(dst, zero_byte, filter_width * input_depth * sizeof(T));
2138           }
2139         }
2140       }
2141     }
2142   }
2143 }
2144 
2145 template <typename T>
2146 void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte,
2147             const RuntimeShape& input_shape, const T* input_data,
2148             const RuntimeShape& output_shape, T* output_data) {
2149   gemmlowp::ScopedProfilingLabel label("Im2col");
2150   const int stride_width = params.stride_width;
2151   const int stride_height = params.stride_height;
2152   const int pad_width = params.padding_values.width;
2153   const int pad_height = params.padding_values.height;
2154   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2155   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2156 
2157   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
2158   const int input_depth = input_shape.Dims(3);
2159   const int input_width = input_shape.Dims(2);
2160   const int input_height = input_shape.Dims(1);
2161   const int output_depth = output_shape.Dims(3);
2162   const int output_width = output_shape.Dims(2);
2163   const int output_height = output_shape.Dims(1);
2164 
2165   int buffer_id = 0;
2166   // Loop over the output nodes.
2167   for (int b = 0; b < batches; ++b) {
2168     for (int h = 0; h < output_height; ++h) {
2169       for (int w = 0; w < output_width; ++w) {
2170         ExtractPatchIntoBufferColumn(
2171             input_shape, w, h, b, kheight, kwidth, stride_width, stride_height,
2172             pad_width, pad_height, input_width, input_height, input_depth,
2173             output_depth, buffer_id, input_data, output_data, zero_byte);
2174         ++buffer_id;
2175       }
2176     }
2177   }
2178 }
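// In effect, the im2col buffer built above holds one patch per output position:
// each patch is a contiguous run of kheight * kwidth * input_depth values,
// filled with zero_byte wherever the patch extends past the image border, so
// the convolution reduces to a single GEMM against the flattened filter.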
2179 
2180 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
2181                  const float* input_data, const RuntimeShape& filter_shape,
2182                  const float* filter_data, const RuntimeShape& bias_shape,
2183                  const float* bias_data, const RuntimeShape& output_shape,
2184                  float* output_data, const RuntimeShape& im2col_shape,
2185                  float* im2col_data) {
2186   const int stride_width = params.stride_width;
2187   const int stride_height = params.stride_height;
2188   const int dilation_width_factor = params.dilation_width_factor;
2189   const int dilation_height_factor = params.dilation_height_factor;
2190   const float output_activation_min = params.float_activation_min;
2191   const float output_activation_max = params.float_activation_max;
2192   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2193   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2194   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2195 
2196   (void)im2col_data;
2197   (void)im2col_shape;
2198   gemmlowp::ScopedProfilingLabel label("Conv");
2199 
2200   // NB: static_cast<float>(0x00000000) == 0.0f
2201   const uint8 float_zero_byte = 0x00;
2202   const float* gemm_input_data = nullptr;
2203   const RuntimeShape* gemm_input_shape = nullptr;
2204   const int filter_width = filter_shape.Dims(2);
2205   const int filter_height = filter_shape.Dims(1);
2206   const bool need_dilated_im2col =
2207       dilation_width_factor != 1 || dilation_height_factor != 1;
2208   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
2209                            filter_width != 1 || filter_height != 1;
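  // In other words, im2col is skipped only for undilated 1x1 filters with unit
  // stride, where each output pixel reads exactly one input pixel's depth
  // vector and the input tensor can be fed to the GEMM directly.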
2210   if (need_dilated_im2col) {
2211     DilatedIm2col(params, float_zero_byte, input_shape, input_data,
2212                   filter_shape, output_shape, im2col_data);
2213     gemm_input_data = im2col_data;
2214     gemm_input_shape = &im2col_shape;
2215   } else if (need_im2col) {
2216     TFLITE_DCHECK(im2col_data);
2217     Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
2218            input_data, im2col_shape, im2col_data);
2219     gemm_input_data = im2col_data;
2220     gemm_input_shape = &im2col_shape;
2221   } else {
2222     // TODO(aselle): We need to make sure to not send im2col if it is not
2223     // needed.
2224     TFLITE_DCHECK(!im2col_data);
2225     gemm_input_data = input_data;
2226     gemm_input_shape = &input_shape;
2227   }
2228 
2229   // The following code computes the matrix multiplication c = a * transpose(b)
2230   // with CBLAS, where:
2231   // * `a` is a matrix with dimensions (m, k).
2232   // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n).
2233   // * `c` is a matrix with dimensions (m, n).
2234   // The naming of the variables is aligned with the CBLAS specification here.
2235   const float* a = gemm_input_data;
2236   const float* b = filter_data;
2237   float* c = output_data;
2238   const int gemm_input_dims = gemm_input_shape->DimensionsCount();
2239   int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
2240   int n = output_shape.Dims(3);
2241   int k = gemm_input_shape->Dims(gemm_input_dims - 1);
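  // Hypothetical example for orientation: an im2col'ed input of shape
  // [1, 28, 28, 72] with 16 output channels gives m = 1 * 28 * 28 = 784 output
  // pixels, n = 16 output channels and k = 72 accumulation elements, so `c`
  // ends up as a 784 x 16 row-major matrix.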
2242 
2243 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
2244   // The stride of matrix a, b and c respectively.
2245   int stride_a = k;
2246   int stride_b = k;
2247   int stride_c = n;
2248 
2249   cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
2250               stride_a, b, stride_b, 0.0f, c, stride_c);
2251 #else
2252   // When an optimized CBLAS implementation is not available, fall back
2253   // to using Eigen.
2254   typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
2255       Matrix;
2256   typedef Eigen::Map<Matrix> MatrixRef;
2257   typedef Eigen::Map<const Matrix> ConstMatrixRef;
2258 
2259   MatrixRef matrix_c(c, m, n);
2260   ConstMatrixRef matrix_a(a, m, k);
2261   ConstMatrixRef matrix_b(b, n, k);
2262 
2263   // The following special casing for when a or b is a vector is required,
2264   // as Eigen seems to fail to make this optimization on its own.
2265   if (n == 1) {
2266     gemmlowp::ScopedProfilingLabel label("GEMV");
2267     matrix_c.col(0).noalias() = matrix_a * matrix_b.row(0).transpose();
2268   } else if (m == 1) {
2269     gemmlowp::ScopedProfilingLabel label("GEMV");
2270     matrix_c.row(0).noalias() = matrix_a.row(0) * matrix_b.transpose();
2271   } else {
2272     gemmlowp::ScopedProfilingLabel label("GEMM");
2273     matrix_c.noalias() = matrix_a * matrix_b.transpose();
2274   }
2275 
2276 #endif  //  defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
2277 
2278   optimized_ops::AddBiasAndEvalActivationFunction(
2279       output_activation_min, output_activation_max, bias_shape, bias_data,
2280       output_shape, output_data);
2281 }
2282 
2283 inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
2284                        const RuntimeShape& input_shape,
2285                        const int8_t* input_data,
2286                        const RuntimeShape& filter_shape,
2287                        const int8_t* filter_data,
2288                        const RuntimeShape& bias_shape, const float* bias_data,
2289                        const RuntimeShape& output_shape, float* output_data,
2290                        const RuntimeShape& im2col_shape, int8_t* im2col_data) {
2291   const int stride_width = params.stride_width;
2292   const int stride_height = params.stride_height;
2293   const float output_activation_min = params.float_activation_min;
2294   const float output_activation_max = params.float_activation_max;
2295   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2296   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2297   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2298 
2299   const int batch_size = input_shape.Dims(0);
2300   const int filter_width = filter_shape.Dims(2);
2301   const int filter_height = filter_shape.Dims(1);
2302 
2303   const int8_t* gemm_input_data = nullptr;
2304   int num_input;
2305   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
2306                            filter_width != 1 || filter_height != 1;
2307 
2308   if (need_im2col) {
2309     TFLITE_DCHECK(im2col_data);
2310     // Symmetric quantization assumes a zero point of 0.
2311     const int input_zero_point = 0;
2312 
2313     Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
2314            input_data, im2col_shape, im2col_data);
2315     gemm_input_data = im2col_data;
2316     num_input = im2col_shape.FlatSize();
2317   } else {
2318     TFLITE_DCHECK(!im2col_data);
2319     gemm_input_data = input_data;
2320     num_input = input_shape.FlatSize();
2321   }
2322 
2323   // Flatten 4D matrices into 2D matrices for matrix multiplication.
2324 
2325   // Flatten so that each filter has its own row.
2326   const int filter_rows = filter_shape.Dims(0);
2327   const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
2328 
2329   // In MatrixBatchVectorMultiplyAccumulate, each output value is the
2330   // dot product of one row of the first matrix with one row of the second
2331   // matrix. Therefore, the number of columns in each matrix must match.
2332   //
2333   // After Im2Col, each input patch becomes a row.
2334   const int gemm_input_cols = filter_cols;
2335   const int gemm_input_rows = num_input / gemm_input_cols;
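  // Hypothetical example: a filter of shape [16, 3, 3, 8] gives filter_rows = 16
  // and filter_cols = 3 * 3 * 8 = 72; if the im2col buffer holds 25 patches,
  // num_input = 25 * 72 = 1800 and gemm_input_rows = 1800 / 72 = 25.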
2336 
2337   const int output_cols = output_shape.Dims(3);
2338   const int output_rows = FlatSizeSkipDim(output_shape, 3);
2339   TFLITE_DCHECK_EQ(output_cols, filter_rows);
2340   TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
2341   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_cols);
2342 
2343   // MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
2344   // input matrix has its own scale factor. This code duplicates the scale
2345   // factors for each row in the same batch.
2346   const int rows_per_batch = gemm_input_rows / batch_size;
2347   for (int i = gemm_input_rows - 1; i >= 0; --i) {
2348     scaling_factors_ptr[i] = scaling_factors_ptr[i / rows_per_batch];
2349   }
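  // For instance, with batch_size = 2 and gemm_input_rows = 8 (so
  // rows_per_batch = 4), entries 4..7 receive the batch-1 factor and entries
  // 0..3 the batch-0 factor. Iterating backwards matters: the original
  // per-batch factors at the low indices must be read before being overwritten.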
2350 
2351   tensor_utils::ZeroVector(output_data, output_rows * output_cols);
2352 
2353   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
2354       filter_data, filter_rows, filter_cols, gemm_input_data,
2355       scaling_factors_ptr, /*n_batch=*/gemm_input_rows, output_data,
2356       /*result_stride=*/1);
2357 
2358   AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
2359                                    bias_shape, bias_data, output_shape,
2360                                    output_data);
2361 }
2362 
2363 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
2364                  const uint8* input_data, const RuntimeShape& filter_shape,
2365                  const uint8* filter_data, const RuntimeShape& bias_shape,
2366                  const int32* bias_data, const RuntimeShape& output_shape,
2367                  uint8* output_data, const RuntimeShape& im2col_shape,
2368                  uint8* im2col_data, gemmlowp::GemmContext* gemm_context) {
2369   gemmlowp::ScopedProfilingLabel label("Conv/8bit");
2370   const int stride_width = params.stride_width;
2371   const int stride_height = params.stride_height;
2372   const int dilation_width_factor = params.dilation_width_factor;
2373   const int dilation_height_factor = params.dilation_height_factor;
2374   const int32 input_offset = params.input_offset;
2375   const int32 filter_offset = params.weights_offset;
2376   const int32 output_offset = params.output_offset;
2377   const int32 output_multiplier = params.output_multiplier;
2378   const int output_shift = params.output_shift;
2379   const int32 output_activation_min = params.quantized_activation_min;
2380   const int32 output_activation_max = params.quantized_activation_max;
2381   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2382   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2383   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2384 
2385   const uint8* gemm_input_data = nullptr;
2386   const RuntimeShape* gemm_input_shape = nullptr;
2387   const int filter_width = filter_shape.Dims(2);
2388   const int filter_height = filter_shape.Dims(1);
2389   const bool need_dilated_im2col =
2390       dilation_width_factor != 1 || dilation_height_factor != 1;
2391   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
2392                            filter_width != 1 || filter_height != 1;
2393   if (need_dilated_im2col) {
2394     TFLITE_DCHECK(im2col_data);
2395     const int input_zero_point = -input_offset;
2396     TFLITE_DCHECK_GE(input_zero_point, 0);
2397     TFLITE_DCHECK_LE(input_zero_point, 255);
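    // params.input_offset is the negated input zero point, so -input_offset
    // recovers the zero-point byte used to fill the padded im2col regions; the
    // checks above keep it in the valid uint8 range.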
2398     DilatedIm2col(params, input_zero_point, input_shape, input_data,
2399                   filter_shape, output_shape, im2col_data);
2400     gemm_input_data = im2col_data;
2401     gemm_input_shape = &im2col_shape;
2402   } else if (need_im2col) {
2403     TFLITE_DCHECK(im2col_data);
2404     const int input_zero_point = -input_offset;
2405     TFLITE_DCHECK_GE(input_zero_point, 0);
2406     TFLITE_DCHECK_LE(input_zero_point, 255);
2407     Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
2408            input_data, im2col_shape, im2col_data);
2409     gemm_input_data = im2col_data;
2410     gemm_input_shape = &im2col_shape;
2411   } else {
2412     TFLITE_DCHECK(!im2col_data);
2413     gemm_input_data = input_data;
2414     gemm_input_shape = &input_shape;
2415   }
2416 
2417   const int gemm_input_rows = gemm_input_shape->Dims(3);
2418   // Using FlatSizeSkipDim causes a segfault in some contexts (see b/79927784),
2419   // though the root cause has not yet been identified. The same applies below
2420   // for the other calls commented out. This is a partial rollback of cl/196819423.
2421   // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
2422   const int gemm_input_cols = gemm_input_shape->Dims(0) *
2423                               gemm_input_shape->Dims(1) *
2424                               gemm_input_shape->Dims(2);
2425   const int filter_rows = filter_shape.Dims(0);
2426   // See b/79927784.
2427   // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
2428   const int filter_cols =
2429       filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
2430   const int output_rows = output_shape.Dims(3);
2431   // See b/79927784.
2432   // const int output_cols = FlatSizeSkipDim(output_shape, 3);
2433   const int output_cols =
2434       output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
2435   TFLITE_DCHECK_EQ(output_rows, filter_rows);
2436   TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
2437   TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
2438   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
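  // The GEMM below computes output(out_channel, pixel) as the dot product of
  // filter row `out_channel` with im2col column `pixel`. Because the output map
  // is column-major with output_depth rows, element (c, p) lands at
  // p * output_depth + c, which matches the NHWC output layout.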
2439 
2440 #ifdef USE_NEON
2441   if (gemm_input_cols == 1 && output_rows >= 4) {
2442     RuntimeShape fc_filter_shape{
2443         filter_shape.Dims(0),
2444         filter_shape.Dims(filter_shape.DimensionsCount() - 1)};
2445 
2446     return FullyConnectedAsGEMV(
2447         *gemm_input_shape, gemm_input_data, input_offset, fc_filter_shape,
2448         filter_data, filter_offset, bias_shape, bias_data, output_offset,
2449         output_multiplier, output_shift, output_activation_min,
2450         output_activation_max, output_shape, output_data, gemm_context);
2451   }
2452 #endif
2453 
2454   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
2455       filter_data, filter_rows, filter_cols);
2456   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
2457       gemm_input_data, gemm_input_rows, gemm_input_cols);
2458   gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
2459       output_data, output_rows, output_cols);
2460   const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
2461       bias_data, output_rows, output_offset, output_multiplier, output_shift,
2462       output_activation_min, output_activation_max);
2463   gemmlowp::GemmWithOutputPipeline<uint8, uint8,
2464                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
2465       gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
2466       input_offset, output_pipeline);
2467 }
2468 
2469 template <typename T>
2470 inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
2471                          const RuntimeShape& unextended_input_shape,
2472                          const T* input_data,
2473                          const RuntimeShape& unextended_output_shape,
2474                          T* output_data) {
2475   gemmlowp::ScopedProfilingLabel label("DepthToSpace");
2476 
2477   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2478   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
2479   const RuntimeShape input_shape =
2480       RuntimeShape::ExtendedShape(4, unextended_input_shape);
2481   const RuntimeShape output_shape =
2482       RuntimeShape::ExtendedShape(4, unextended_output_shape);
2483 
2484   const int input_depth = input_shape.Dims(3);
2485   const int input_width = input_shape.Dims(2);
2486   const int input_height = input_shape.Dims(1);
2487 
2488   const int output_depth = output_shape.Dims(3);
2489   const int batch_size = output_shape.Dims(0);
2490 
2491   // Number of contiguous values that we can copy in one iteration.
2492   const int stride = op_params.block_size * output_depth;
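  // For example (hypothetical sizes): with block_size = 2 and output_depth = 4
  // (so input_depth = 16), stride = 8 and each memcpy below moves the data for
  // two adjacent output pixels in one shot.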
2493 
2494   for (int batch = 0; batch < batch_size; ++batch) {
2495     for (int in_h = 0; in_h < input_height; ++in_h) {
2496       const T* input_ptr = input_data + Offset(input_shape, batch, in_h, 0, 0);
2497       for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
2498         const T* src = input_ptr;
2499         for (int in_w = 0; in_w < input_width; ++in_w) {
2500           memcpy(output_data, src, stride * sizeof(T));
2501           output_data += stride;
2502           src += input_depth;
2503         }
2504         input_ptr += stride;
2505       }
2506     }
2507   }
2508 }
2509 
2510 template <typename T>
2511 inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
2512                          const RuntimeShape& unextended_input_shape,
2513                          const T* input_data,
2514                          const RuntimeShape& unextended_output_shape,
2515                          T* output_data) {
2516   gemmlowp::ScopedProfilingLabel label("SpaceToDepth");
2517 
2518   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2519   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
2520   const RuntimeShape input_shape =
2521       RuntimeShape::ExtendedShape(4, unextended_input_shape);
2522   const RuntimeShape output_shape =
2523       RuntimeShape::ExtendedShape(4, unextended_output_shape);
2524 
2525   const int output_depth = output_shape.Dims(3);
2526   const int output_width = output_shape.Dims(2);
2527   const int output_height = output_shape.Dims(1);
2528 
2529   const int input_depth = input_shape.Dims(3);
2530   const int batch_size = input_shape.Dims(0);
2531 
2532   // Number of contiguous values that we can copy in one iteration.
2533   const int stride = op_params.block_size * input_depth;
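  // Mirror image of DepthToSpace above: each memcpy gathers block_size
  // horizontally adjacent input pixels (block_size * input_depth contiguous
  // values) into a segment of a single output pixel's depth slice.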
2534 
2535   for (int batch = 0; batch < batch_size; ++batch) {
2536     for (int out_h = 0; out_h < output_height; ++out_h) {
2537       T* output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0);
2538       for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
2539         T* dst = output_ptr;
2540         for (int out_w = 0; out_w < output_width; ++out_w) {
2541           memcpy(dst, input_data, stride * sizeof(T));
2542           input_data += stride;
2543           dst += output_depth;
2544         }
2545         output_ptr += stride;
2546       }
2547     }
2548   }
2549 }
2550 
2551 inline void Relu(const RuntimeShape& input_shape, const float* input_data,
2552                  const RuntimeShape& output_shape, float* output_data) {
2553   gemmlowp::ScopedProfilingLabel label("Relu (not fused)");
2554 
2555   const auto input = MapAsVector(input_data, input_shape);
2556   auto output = MapAsVector(output_data, output_shape);
2557   output = input.cwiseMax(0.0f);
2558 }
2559 
2560 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
2561                             const RuntimeShape& input_shape,
2562                             const float* input_data,
2563                             const RuntimeShape& output_shape,
2564                             float* output_data) {
2565   gemmlowp::ScopedProfilingLabel label("L2Normalization");
2566   const int trailing_dim = input_shape.DimensionsCount() - 1;
2567   const int outer_size =
2568       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
2569   const int depth =
2570       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
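  // For each inner slice of `depth` values this computes
  //   output[c] = input[c] / sqrt(sum_c input[c]^2),
  // i.e. a plain L2 normalization along the trailing dimension.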
2571   for (int i = 0; i < outer_size; ++i) {
2572     float squared_l2_norm = 0;
2573     for (int c = 0; c < depth; ++c) {
2574       const float val = input_data[c];
2575       squared_l2_norm += val * val;
2576     }
2577     const float l2_norm = std::sqrt(squared_l2_norm);
2578     for (int c = 0; c < depth; ++c) {
2579       *output_data = *input_data / l2_norm;
2580       ++output_data;
2581       ++input_data;
2582     }
2583   }
2584 }
2585 
2586 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
2587                             const RuntimeShape& input_shape,
2588                             const uint8* input_data,
2589                             const RuntimeShape& output_shape,
2590                             uint8* output_data) {
2591   gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit");
2592   const int trailing_dim = input_shape.DimensionsCount() - 1;
2593   const int depth =
2594       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
2595   const int outer_size =
2596       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
2597   const int32 input_zero_point = op_params.input_zero_point;
2598   for (int i = 0; i < outer_size; ++i) {
2599     int32 square_l2_norm = 0;
2600     for (int c = 0; c < depth; c++) {
2601       // Note that input_data advances by depth in the second pass below.
2602       int32 diff = input_data[c] - input_zero_point;
2603       square_l2_norm += diff * diff;
2604     }
2605     int32 inv_l2norm_multiplier;
2606     int inv_l2norm_shift;
2607     GetInvSqrtQuantizedMultiplierExp(square_l2_norm, kReverseShift,
2608                                      &inv_l2norm_multiplier, &inv_l2norm_shift);
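    // The pair (inv_l2norm_multiplier, inv_l2norm_shift) is a fixed-point
    // approximation of 1 / sqrt(square_l2_norm). The loop below therefore
    // computes roughly 128 + 128 * diff / l2_norm per element, i.e. the output
    // uses a fixed zero point of 128 and a scale of 1/128, clamped to [0, 255].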
2609 
2610     for (int c = 0; c < depth; c++) {
2611       int32 diff = *input_data - input_zero_point;
2612       int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
2613           128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
2614       int32 unclamped_output_val = 128 + rescaled_diff;
2615       int32 output_val = std::min(255, std::max(0, unclamped_output_val));
2616       *output_data = static_cast<uint8>(output_val);
2617       ++input_data;
2618       ++output_data;
2619     }
2620   }
2621 }
2622 
2623 inline void Add(const ArithmeticParams& params,
2624                 const RuntimeShape& input1_shape, const float* input1_data,
2625                 const RuntimeShape& input2_shape, const float* input2_data,
2626                 const RuntimeShape& output_shape, float* output_data) {
2627   gemmlowp::ScopedProfilingLabel label("Add");
2628 
2629   int i = 0;
2630   const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
2631 #ifdef USE_NEON
2632   const auto activation_min = vdupq_n_f32(params.float_activation_min);
2633   const auto activation_max = vdupq_n_f32(params.float_activation_max);
2634   for (; i <= size - 16; i += 16) {
2635     auto a10 = vld1q_f32(input1_data + i);
2636     auto a11 = vld1q_f32(input1_data + i + 4);
2637     auto a12 = vld1q_f32(input1_data + i + 8);
2638     auto a13 = vld1q_f32(input1_data + i + 12);
2639     auto a20 = vld1q_f32(input2_data + i);
2640     auto a21 = vld1q_f32(input2_data + i + 4);
2641     auto a22 = vld1q_f32(input2_data + i + 8);
2642     auto a23 = vld1q_f32(input2_data + i + 12);
2643     auto x0 = vaddq_f32(a10, a20);
2644     auto x1 = vaddq_f32(a11, a21);
2645     auto x2 = vaddq_f32(a12, a22);
2646     auto x3 = vaddq_f32(a13, a23);
2647     x0 = vmaxq_f32(activation_min, x0);
2648     x1 = vmaxq_f32(activation_min, x1);
2649     x2 = vmaxq_f32(activation_min, x2);
2650     x3 = vmaxq_f32(activation_min, x3);
2651     x0 = vminq_f32(activation_max, x0);
2652     x1 = vminq_f32(activation_max, x1);
2653     x2 = vminq_f32(activation_max, x2);
2654     x3 = vminq_f32(activation_max, x3);
2655     vst1q_f32(output_data + i, x0);
2656     vst1q_f32(output_data + i + 4, x1);
2657     vst1q_f32(output_data + i + 8, x2);
2658     vst1q_f32(output_data + i + 12, x3);
2659   }
2660   for (; i <= size - 4; i += 4) {
2661     auto a1 = vld1q_f32(input1_data + i);
2662     auto a2 = vld1q_f32(input2_data + i);
2663     auto x = vaddq_f32(a1, a2);
2664     x = vmaxq_f32(activation_min, x);
2665     x = vminq_f32(activation_max, x);
2666     vst1q_f32(output_data + i, x);
2667   }
2668 #endif  // NEON
2669 
2670   for (; i < size; i++) {
2671     auto x = input1_data[i] + input2_data[i];
2672     output_data[i] = ActivationFunctionWithMinMax(
2673         x, params.float_activation_min, params.float_activation_max);
2674   }
2675 }
2676 
2677 // Element-wise add that can often be used for inner loop of broadcast add as
2678 // well as the non-broadcast add.
2679 inline void AddElementwise(int size, const ArithmeticParams& params,
2680                            const uint8* input1_data, const uint8* input2_data,
2681                            uint8* output_data) {
2682   gemmlowp::ScopedProfilingLabel label("AddElementwise/8bit");
2683   int i = 0;
2684   TFLITE_DCHECK_GT(params.input1_offset, -256);
2685   TFLITE_DCHECK_GT(params.input2_offset, -256);
2686   TFLITE_DCHECK_LT(params.input1_offset, 256);
2687   TFLITE_DCHECK_LT(params.input2_offset, 256);
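  // Both paths below implement the same fixed-point recipe (see the scalar tail
  // for the reference version): add the input offsets, scale each input into a
  // shared higher-precision domain via (x << left_shift) * multiplier >> shift,
  // add, rescale with the output multiplier/shift, add output_offset and clamp.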
2688 #ifdef USE_NEON
2689   const uint8x8_t output_activation_min_vector =
2690       vdup_n_u8(params.quantized_activation_min);
2691   const uint8x8_t output_activation_max_vector =
2692       vdup_n_u8(params.quantized_activation_max);
2693   for (; i <= size - 8; i += 8) {
2694     const uint8x8_t input1_val_original = vld1_u8(input1_data + i);
2695     const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
2696     const int16x8_t input1_val_s16 =
2697         vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
2698     const int16x8_t input2_val_s16 =
2699         vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
2700     const int16x8_t input1_val =
2701         vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
2702     const int16x8_t input2_val =
2703         vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
2704     const int16x4_t input1_val_high = vget_high_s16(input1_val);
2705     const int16x4_t input1_val_low = vget_low_s16(input1_val);
2706     const int16x4_t input2_val_high = vget_high_s16(input2_val);
2707     const int16x4_t input2_val_low = vget_low_s16(input2_val);
2708     int32x4_t x11 = vmovl_s16(input1_val_low);
2709     int32x4_t x12 = vmovl_s16(input1_val_high);
2710     int32x4_t x21 = vmovl_s16(input2_val_low);
2711     int32x4_t x22 = vmovl_s16(input2_val_high);
2712     const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
2713     x11 = vshlq_s32(x11, left_shift_dup);
2714     x12 = vshlq_s32(x12, left_shift_dup);
2715     x21 = vshlq_s32(x21, left_shift_dup);
2716     x22 = vshlq_s32(x22, left_shift_dup);
2717     x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
2718     x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
2719     x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
2720     x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
2721     const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
2722     const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
2723     x11 = vshlq_s32(x11, input1_shift_dup);
2724     x12 = vshlq_s32(x12, input1_shift_dup);
2725     x21 = vshlq_s32(x21, input2_shift_dup);
2726     x22 = vshlq_s32(x22, input2_shift_dup);
2727     int32x4_t s1 = vaddq_s32(x11, x21);
2728     int32x4_t s2 = vaddq_s32(x12, x22);
2729     s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
2730     s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
2731     using gemmlowp::RoundingDivideByPOT;
2732     s1 = RoundingDivideByPOT(s1, -params.output_shift);
2733     s2 = RoundingDivideByPOT(s2, -params.output_shift);
2734     const int16x4_t s1_narrowed = vmovn_s32(s1);
2735     const int16x4_t s2_narrowed = vmovn_s32(s2);
2736     const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
2737                                   vdupq_n_s16(params.output_offset));
2738     const uint8x8_t clamped =
2739         vmax_u8(output_activation_min_vector,
2740                 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
2741     vst1_u8(output_data + i, clamped);
2742   }
2743 #endif  // NEON
2744 
2745   for (; i < size; ++i) {
2746     const int32 input1_val = params.input1_offset + input1_data[i];
2747     const int32 input2_val = params.input2_offset + input2_data[i];
2748     const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
2749     const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
2750     const int32 scaled_input1_val =
2751         MultiplyByQuantizedMultiplierSmallerThanOneExp(
2752             shifted_input1_val, params.input1_multiplier, params.input1_shift);
2753     const int32 scaled_input2_val =
2754         MultiplyByQuantizedMultiplierSmallerThanOneExp(
2755             shifted_input2_val, params.input2_multiplier, params.input2_shift);
2756     const int32 raw_sum = scaled_input1_val + scaled_input2_val;
2757     const int32 raw_output =
2758         MultiplyByQuantizedMultiplierSmallerThanOneExp(
2759             raw_sum, params.output_multiplier, params.output_shift) +
2760         params.output_offset;
2761     const int32 clamped_output =
2762         std::min(params.quantized_activation_max,
2763                  std::max(params.quantized_activation_min, raw_output));
2764     output_data[i] = static_cast<uint8>(clamped_output);
2765   }
2766 }
2767 
2768 // Scalar-broadcast add that can be used for inner loop of more general
2769 // broadcast add, so that, for example, scalar-broadcast with batch will still
2770 // be fast.
2771 inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
2772                                uint8 input1_data, const uint8* input2_data,
2773                                uint8* output_data) {
2774   using gemmlowp::RoundingDivideByPOT;
2775 
2776   gemmlowp::ScopedProfilingLabel label("AddScalarBroadcast/8bit");
2777   TFLITE_DCHECK_GT(params.input1_offset, -256);
2778   TFLITE_DCHECK_GT(params.input2_offset, -256);
2779   TFLITE_DCHECK_LT(params.input1_offset, 256);
2780   TFLITE_DCHECK_LT(params.input2_offset, 256);
2781 
2782   int i = 0;
2783 
2784 #ifdef USE_NEON
2785   const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
2786   const uint8x8_t output_activation_min_vector =
2787       vdup_n_u8(params.quantized_activation_min);
2788   const uint8x8_t output_activation_max_vector =
2789       vdup_n_u8(params.quantized_activation_max);
2790 
2791   // Process broadcast scalar.
2792   const uint8x8_t input1_val_original = vdup_n_u8(input1_data);
2793   const int16x8_t input1_val_s16 =
2794       vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
2795   const int16x8_t input1_val =
2796       vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
2797   const int16x4_t input1_val_high = vget_high_s16(input1_val);
2798   const int16x4_t input1_val_low = vget_low_s16(input1_val);
2799   int32x4_t x11 = vmovl_s16(input1_val_low);
2800   int32x4_t x12 = vmovl_s16(input1_val_high);
2801   x11 = vshlq_s32(x11, left_shift_dup);
2802   x12 = vshlq_s32(x12, left_shift_dup);
2803   x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
2804   x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
2805   const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
2806   x11 = vshlq_s32(x11, input1_shift_dup);
2807   x12 = vshlq_s32(x12, input1_shift_dup);
2808 
2809   for (; i <= size - 8; i += 8) {
2810     const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
2811     const int16x8_t input2_val_s16 =
2812         vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
2813     const int16x8_t input2_val =
2814         vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
2815     const int16x4_t input2_val_high = vget_high_s16(input2_val);
2816     const int16x4_t input2_val_low = vget_low_s16(input2_val);
2817     int32x4_t x21 = vmovl_s16(input2_val_low);
2818     int32x4_t x22 = vmovl_s16(input2_val_high);
2819     x21 = vshlq_s32(x21, left_shift_dup);
2820     x22 = vshlq_s32(x22, left_shift_dup);
2821     x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
2822     x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
2823     const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
2824     x21 = vshlq_s32(x21, input2_shift_dup);
2825     x22 = vshlq_s32(x22, input2_shift_dup);
2826     int32x4_t s1 = vaddq_s32(x11, x21);
2827     int32x4_t s2 = vaddq_s32(x12, x22);
2828     s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
2829     s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
2830     s1 = RoundingDivideByPOT(s1, -params.output_shift);
2831     s2 = RoundingDivideByPOT(s2, -params.output_shift);
2832     const int16x4_t s1_narrowed = vmovn_s32(s1);
2833     const int16x4_t s2_narrowed = vmovn_s32(s2);
2834     const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
2835                                   vdupq_n_s16(params.output_offset));
2836     const uint8x8_t clamped =
2837         vmax_u8(output_activation_min_vector,
2838                 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
2839     vst1_u8(output_data + i, clamped);
2840   }
2841 #endif  // NEON
2842 
2843   if (i < size) {
2844     // Process broadcast scalar.
2845     const int32 input1_val = params.input1_offset + input1_data;
2846     const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
2847     const int32 scaled_input1_val =
2848         MultiplyByQuantizedMultiplierSmallerThanOneExp(
2849             shifted_input1_val, params.input1_multiplier, params.input1_shift);
2850 
2851     for (; i < size; ++i) {
2852       const int32 input2_val = params.input2_offset + input2_data[i];
2853       const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
2854       const int32 scaled_input2_val =
2855           MultiplyByQuantizedMultiplierSmallerThanOneExp(
2856               shifted_input2_val, params.input2_multiplier,
2857               params.input2_shift);
2858       const int32 raw_sum = scaled_input1_val + scaled_input2_val;
2859       const int32 raw_output =
2860           MultiplyByQuantizedMultiplierSmallerThanOneExp(
2861               raw_sum, params.output_multiplier, params.output_shift) +
2862           params.output_offset;
2863       const int32 clamped_output =
2864           std::min(params.quantized_activation_max,
2865                    std::max(params.quantized_activation_min, raw_output));
2866       output_data[i] = static_cast<uint8>(clamped_output);
2867     }
2868   }
2869 }
2870 
2871 inline void Add(const ArithmeticParams& params,
2872                 const RuntimeShape& input1_shape, const uint8* input1_data,
2873                 const RuntimeShape& input2_shape, const uint8* input2_data,
2874                 const RuntimeShape& output_shape, uint8* output_data) {
2875   TFLITE_DCHECK_LE(params.quantized_activation_min,
2876                    params.quantized_activation_max);
2877   gemmlowp::ScopedProfilingLabel label("Add/8bit");
2878   const int flat_size =
2879       MatchingFlatSize(input1_shape, input2_shape, output_shape);
2880 
2881   TFLITE_DCHECK_GT(params.input1_offset, -256);
2882   TFLITE_DCHECK_GT(params.input2_offset, -256);
2883   TFLITE_DCHECK_LT(params.input1_offset, 256);
2884   TFLITE_DCHECK_LT(params.input2_offset, 256);
2885   AddElementwise(flat_size, params, input1_data, input2_data, output_data);
2886 }
2887 
2888 inline void Add(const ArithmeticParams& params,
2889                 const RuntimeShape& input1_shape, const int16* input1_data,
2890                 const RuntimeShape& input2_shape, const int16* input2_data,
2891                 const RuntimeShape& output_shape, int16* output_data) {
2892   gemmlowp::ScopedProfilingLabel label("Add/Int16");
2893   TFLITE_DCHECK_LE(params.quantized_activation_min,
2894                    params.quantized_activation_max);
2895 
2896   const int input1_shift = params.input1_shift;
2897   const int flat_size =
2898       MatchingFlatSize(output_shape, input1_shape, input2_shape);
2899   const int16 output_activation_min = params.quantized_activation_min;
2900   const int16 output_activation_max = params.quantized_activation_max;
2901 
2902   TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
2903   TFLITE_DCHECK_LE(input1_shift, 0);
2904   TFLITE_DCHECK_LE(params.input2_shift, 0);
2905   const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
2906   const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
2907   const int input_right_shift =
2908       input1_shift == 0 ? -params.input2_shift : -input1_shift;
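  // One operand is already expressed in the output's Q0.15 scale (shift == 0);
  // the other is brought onto that scale by a rounding right shift before the
  // saturating add below.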
2909 
2910   for (int i = 0; i < flat_size; i++) {
2911     // F0 uses 0 integer bits, range [-1, 1].
2912     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2913 
2914     F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
2915     F0 scaled_input = F0::FromRaw(
2916         gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
2917     F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
2918     const int16 raw_output = result.raw();
2919     const int16 clamped_output = std::min(
2920         output_activation_max, std::max(output_activation_min, raw_output));
2921     output_data[i] = clamped_output;
2922   }
2923 }
2924 
2925 inline void Add(const ArithmeticParams& params,
2926                 const RuntimeShape& input1_shape, const int32* input1_data,
2927                 const RuntimeShape& input2_shape, const int32* input2_data,
2928                 const RuntimeShape& output_shape, int32* output_data) {
2929   gemmlowp::ScopedProfilingLabel label("Add/int32");
2930 
2931   auto input1_map = MapAsVector(input1_data, input1_shape);
2932   auto input2_map = MapAsVector(input2_data, input2_shape);
2933   auto output_map = MapAsVector(output_data, output_shape);
2934   if (input1_shape == input2_shape) {
2935     output_map.array() = input1_map.array() + input2_map.array();
2936   } else if (input2_shape.FlatSize() == 1) {
2937     auto scalar = input2_data[0];
2938     output_map.array() = input1_map.array() + scalar;
2939   } else if (input1_shape.FlatSize() == 1) {
2940     auto scalar = input1_data[0];
2941     output_map.array() = scalar + input2_map.array();
2942   } else {
2943     // Should not come here.
2944     TFLITE_DCHECK(false);
2945   }
2946   output_map = output_map.cwiseMax(params.quantized_activation_min);
2947   output_map = output_map.cwiseMin(params.quantized_activation_max);
2948 }
2949 
2950 inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
2951                                  const RuntimeShape& unswitched_input1_shape,
2952                                  const uint8* unswitched_input1_data,
2953                                  const RuntimeShape& unswitched_input2_shape,
2954                                  const uint8* unswitched_input2_data,
2955                                  const RuntimeShape& output_shape,
2956                                  uint8* output_data) {
2957   gemmlowp::ScopedProfilingLabel label("BroadcastAddFivefold/8bit");
2958 
2959   ArithmeticParams switched_params = unswitched_params;
2960   switched_params.input1_offset = unswitched_params.input2_offset;
2961   switched_params.input1_multiplier = unswitched_params.input2_multiplier;
2962   switched_params.input1_shift = unswitched_params.input2_shift;
2963   switched_params.input2_offset = unswitched_params.input1_offset;
2964   switched_params.input2_multiplier = unswitched_params.input1_multiplier;
2965   switched_params.input2_shift = unswitched_params.input1_shift;
2966 
2967   const bool use_unswitched =
2968       unswitched_params.broadcast_category ==
2969       tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
2970 
2971   const ArithmeticParams& params =
2972       use_unswitched ? unswitched_params : switched_params;
2973   const uint8* input1_data =
2974       use_unswitched ? unswitched_input1_data : unswitched_input2_data;
2975   const uint8* input2_data =
2976       use_unswitched ? unswitched_input2_data : unswitched_input1_data;
2977 
2978   // Fivefold nested loops. The second input resets its position for each
2979   // iteration of the second loop. The first input resets its position at the
2980   // beginning of the fourth loop. The innermost loop is an elementwise add of
2981   // sections of the arrays.
2982   uint8* output_data_ptr = output_data;
2983   const uint8* input1_data_ptr = input1_data;
2984   const uint8* input2_data_reset = input2_data;
2985   // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so are shared
2986   // between the input shapes. y3 is always broadcast for input 1, so its
2987   // dimension there is 1, whereas y1 may optionally be broadcast for input 2.
2988   // Put another way,
2989   // input1.shape.FlatSize = y0 * y1 * y2 * y4,
2990   // input2.shape.FlatSize = y0 * y2 * y3 * y4.
2991   int y0 = params.broadcast_shape[0];
2992   int y1 = params.broadcast_shape[1];
2993   int y2 = params.broadcast_shape[2];
2994   int y3 = params.broadcast_shape[3];
2995   int y4 = params.broadcast_shape[4];
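  // Worked example (hypothetical): y0 = 1, y1 = 8, y2 = 1, y3 = 1, y4 = 16
  // describes adding a 16-element row vector (input2) to each of the 8 rows of
  // an 8 x 16 input1; AddElementwise then runs 8 times on 16-element chunks,
  // with input2 rewound to the start of the vector for every row.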
2996   if (y4 > 1) {
2997     // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
2998     // dimension.
2999     for (int i0 = 0; i0 < y0; ++i0) {
3000       const uint8* input2_data_ptr = nullptr;
3001       for (int i1 = 0; i1 < y1; ++i1) {
3002         input2_data_ptr = input2_data_reset;
3003         for (int i2 = 0; i2 < y2; ++i2) {
3004           for (int i3 = 0; i3 < y3; ++i3) {
3005             AddElementwise(y4, params, input1_data_ptr, input2_data_ptr,
3006                            output_data_ptr);
3007             input2_data_ptr += y4;
3008             output_data_ptr += y4;
3009           }
3010           // We have broadcast y4 of input1 data y3 times, and now move on.
3011           input1_data_ptr += y4;
3012         }
3013       }
3014       // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
3015       input2_data_reset = input2_data_ptr;
3016     }
3017   } else {
3018     // Special case of y4 == 1, in which the innermost loop is a single element
3019     // and can be combined with the next (y3) as an inner broadcast.
3020     //
3021     // Note that this handles the case of pure scalar broadcast when
3022     // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
3023     // broadcast with batch (as y2 > 1).
3024     //
3025   // NOTE: The process is the same as the general case above, except that it is
3026   // simplified for y4 == 1, and the loop over y3 is contained within the
3027   // AddScalarBroadcast function.
3028     for (int i0 = 0; i0 < y0; ++i0) {
3029       const uint8* input2_data_ptr = nullptr;
3030       for (int i1 = 0; i1 < y1; ++i1) {
3031         input2_data_ptr = input2_data_reset;
3032         for (int i2 = 0; i2 < y2; ++i2) {
3033           AddScalarBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
3034                              output_data_ptr);
3035           input2_data_ptr += y3;
3036           output_data_ptr += y3;
3037           input1_data_ptr += 1;
3038         }
3039       }
3040       input2_data_reset = input2_data_ptr;
3041     }
3042   }
3043 }
3044 
3045 inline void Mul(const ArithmeticParams& params,
3046                 const RuntimeShape& input1_shape, const float* input1_data,
3047                 const RuntimeShape& input2_shape, const float* input2_data,
3048                 const RuntimeShape& output_shape, float* output_data) {
3049   gemmlowp::ScopedProfilingLabel label("Mul");
3050   const float output_activation_min = params.float_activation_min;
3051   const float output_activation_max = params.float_activation_max;
3052 
3053   int i = 0;
3054   const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
3055 #ifdef USE_NEON
3056   const auto activation_min = vdupq_n_f32(output_activation_min);
3057   const auto activation_max = vdupq_n_f32(output_activation_max);
3058   for (; i <= size - 16; i += 16) {
3059     auto a10 = vld1q_f32(input1_data + i);
3060     auto a11 = vld1q_f32(input1_data + i + 4);
3061     auto a12 = vld1q_f32(input1_data + i + 8);
3062     auto a13 = vld1q_f32(input1_data + i + 12);
3063     auto a20 = vld1q_f32(input2_data + i);
3064     auto a21 = vld1q_f32(input2_data + i + 4);
3065     auto a22 = vld1q_f32(input2_data + i + 8);
3066     auto a23 = vld1q_f32(input2_data + i + 12);
3067     auto x0 = vmulq_f32(a10, a20);
3068     auto x1 = vmulq_f32(a11, a21);
3069     auto x2 = vmulq_f32(a12, a22);
3070     auto x3 = vmulq_f32(a13, a23);
3071 
3072     x0 = vmaxq_f32(activation_min, x0);
3073     x1 = vmaxq_f32(activation_min, x1);
3074     x2 = vmaxq_f32(activation_min, x2);
3075     x3 = vmaxq_f32(activation_min, x3);
3076     x0 = vminq_f32(activation_max, x0);
3077     x1 = vminq_f32(activation_max, x1);
3078     x2 = vminq_f32(activation_max, x2);
3079     x3 = vminq_f32(activation_max, x3);
3080 
3081     vst1q_f32(output_data + i, x0);
3082     vst1q_f32(output_data + i + 4, x1);
3083     vst1q_f32(output_data + i + 8, x2);
3084     vst1q_f32(output_data + i + 12, x3);
3085   }
3086   for (; i <= size - 4; i += 4) {
3087     auto a1 = vld1q_f32(input1_data + i);
3088     auto a2 = vld1q_f32(input2_data + i);
3089     auto x = vmulq_f32(a1, a2);
3090 
3091     x = vmaxq_f32(activation_min, x);
3092     x = vminq_f32(activation_max, x);
3093 
3094     vst1q_f32(output_data + i, x);
3095   }
3096 #endif  // NEON
3097 
3098   for (; i < size; i++) {
3099     auto x = input1_data[i] * input2_data[i];
3100     output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
3101                                                   output_activation_max);
3102   }
3103 }
3104 
3105 inline void Mul(const ArithmeticParams& params,
3106                 const RuntimeShape& input1_shape, const int32* input1_data,
3107                 const RuntimeShape& input2_shape, const int32* input2_data,
3108                 const RuntimeShape& output_shape, int32* output_data) {
3109   gemmlowp::ScopedProfilingLabel label("Mul/int32/activation");
3110 
3111   const int flat_size =
3112       MatchingFlatSize(input1_shape, input2_shape, output_shape);
3113   const int32 output_activation_min = params.quantized_activation_min;
3114   const int32 output_activation_max = params.quantized_activation_max;
3115   for (int i = 0; i < flat_size; ++i) {
3116     output_data[i] = ActivationFunctionWithMinMax(
3117         input1_data[i] * input2_data[i], output_activation_min,
3118         output_activation_max);
3119   }
3120 }
3121 
3122 inline void MulNoActivation(const ArithmeticParams& params,
3123                             const RuntimeShape& input1_shape,
3124                             const int32* input1_data,
3125                             const RuntimeShape& input2_shape,
3126                             const int32* input2_data,
3127                             const RuntimeShape& output_shape,
3128                             int32* output_data) {
3129   gemmlowp::ScopedProfilingLabel label("Mul/int32");
3130 
3131   auto input1_map = MapAsVector(input1_data, input1_shape);
3132   auto input2_map = MapAsVector(input2_data, input2_shape);
3133   auto output_map = MapAsVector(output_data, output_shape);
3134   if (input1_shape == input2_shape) {
3135     output_map.array() = input1_map.array() * input2_map.array();
3136   } else if (input2_shape.FlatSize() == 1) {
3137     auto scalar = input2_data[0];
3138     output_map.array() = input1_map.array() * scalar;
3139   } else if (input1_shape.FlatSize() == 1) {
3140     auto scalar = input1_data[0];
3141     output_map.array() = scalar * input2_map.array();
3142   } else {
3143     // Should not come here.
3144     TFLITE_DCHECK(false);
3145   }
3146 }
3147 
3148 inline void Mul(const ArithmeticParams& params,
3149                 const RuntimeShape& input1_shape, const int16* input1_data,
3150                 const RuntimeShape& input2_shape, const int16* input2_data,
3151                 const RuntimeShape& output_shape, int16* output_data) {
3152   gemmlowp::ScopedProfilingLabel label("Mul/Int16/NoActivation");
3153   // This is a copy of the reference implementation. We do not currently have a
3154   // properly optimized version.
3155 
3156   const int flat_size =
3157       MatchingFlatSize(input1_shape, input2_shape, output_shape);
3158 
3159   for (int i = 0; i < flat_size; i++) {
3160     // F0 uses 0 integer bits, range [-1, 1].
3161     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
3162 
3163     F0 unclamped_result =
3164         F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
3165     output_data[i] = unclamped_result.raw();
3166   }
3167 }
3168 
3169 inline void Mul(const ArithmeticParams& params,
3170                 const RuntimeShape& input1_shape, const int16* input1_data,
3171                 const RuntimeShape& input2_shape, const int16* input2_data,
3172                 const RuntimeShape& output_shape, uint8* output_data) {
3173   gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8");
3174   // This is a copy of the reference implementation. We do not currently have a
3175   // properly optimized version.
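  // For intuition: the Q0.15 product below is rescaled by
  // RoundingDivideByPOT(..., 8), i.e. divided by 2^8 with rounding, mapping
  // the [-1, 1] fixed-point range to roughly [-128, 128] before the output
  // offset and clamping are applied. E.g. a raw product of 8192 (0.25)
  // becomes 32; with an output_offset of 128 the stored uint8 would be 160.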
3176   const int32 output_activation_min = params.quantized_activation_min;
3177   const int32 output_activation_max = params.quantized_activation_max;
3178   const int32 output_offset = params.output_offset;
3179   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
3180 
3181   const int flat_size =
3182       MatchingFlatSize(input1_shape, input2_shape, output_shape);
3183 
3184   for (int i = 0; i < flat_size; i++) {
3185     // F0 uses 0 integer bits, range [-1, 1].
3186     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
3187 
3188     F0 unclamped_result =
3189         F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
3190     int16 rescaled_result =
3191         gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
3192     int16 clamped_result =
3193         std::min<int16>(output_activation_max - output_offset, rescaled_result);
3194     clamped_result =
3195         std::max<int16>(output_activation_min - output_offset, clamped_result);
3196     output_data[i] = output_offset + clamped_result;
3197   }
3198 }
3199 
3200 // Element-wise mul that can often be used for the inner loop of a broadcast
3201 // Mul as well as for the non-broadcast Mul.
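// Roughly, per element this computes, in integer arithmetic,
//   acc = (input1[i] + input1_offset) * (input2[i] + input2_offset)
//   out = clamp(output_offset + round(acc * M))
// where M is the real multiplier input1_scale * input2_scale / output_scale,
// encoded as the Q31 value output_multiplier together with the exponent
// output_shift, and the offsets are typically the negated input zero points
// and the output zero point.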
3202 inline void MulElementwise(int size, const ArithmeticParams& params,
3203                            const uint8* input1_data, const uint8* input2_data,
3204                            uint8* output_data) {
3205   int i = 0;
3206   TFLITE_DCHECK_GT(params.input1_offset, -256);
3207   TFLITE_DCHECK_LT(params.input1_offset, 256);
3208   TFLITE_DCHECK_GT(params.input2_offset, -256);
3209   TFLITE_DCHECK_LT(params.input2_offset, 256);
3210   TFLITE_DCHECK_GT(params.output_offset, -256);
3211   TFLITE_DCHECK_LT(params.output_offset, 256);
3212 #ifdef USE_NEON
3213   const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
3214   const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
3215   const auto output_offset_vector = vdupq_n_s16(params.output_offset);
3216   const auto output_activation_min_vector =
3217       vdup_n_u8(params.quantized_activation_min);
3218   const auto output_activation_max_vector =
3219       vdup_n_u8(params.quantized_activation_max);
3220   for (; i <= size - 8; i += 8) {
3221     // We load / store 8 at a time, multiplying as two sets of 4 int32s.
3222     const auto input1_val_original = vld1_u8(input1_data + i);
3223     const auto input2_val_original = vld1_u8(input2_data + i);
3224     const auto input1_val_s16 =
3225         vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
3226     const auto input2_val_s16 =
3227         vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
3228     const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
3229     const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
3230 
3231     const auto input1_val_low = vget_low_s16(input1_val);
3232     const auto input1_val_high = vget_high_s16(input1_val);
3233     const auto input2_val_low = vget_low_s16(input2_val);
3234     const auto input2_val_high = vget_high_s16(input2_val);
3235 
3236     auto p1 = vmull_s16(input2_val_low, input1_val_low);
3237     auto p2 = vmull_s16(input2_val_high, input1_val_high);
3238 
3239     p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
3240     p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
3241     using gemmlowp::RoundingDivideByPOT;
3242     p1 = RoundingDivideByPOT(p1, -params.output_shift);
3243     p2 = RoundingDivideByPOT(p2, -params.output_shift);
3244 
3245     const auto p1_narrowed = vmovn_s32(p1);
3246     const auto p2_narrowed = vmovn_s32(p2);
3247     const auto p =
3248         vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
3249     const auto clamped =
3250         vmax_u8(output_activation_min_vector,
3251                 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
3252     vst1_u8(output_data + i, clamped);
3253   }
3254 #endif  // USE_NEON
3255 
3256   for (; i < size; ++i) {
3257     const int32 input1_val = params.input1_offset + input1_data[i];
3258     const int32 input2_val = params.input2_offset + input2_data[i];
3259     const int32 unclamped_result =
3260         params.output_offset +
3261         MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
3262                                                        params.output_multiplier,
3263                                                        params.output_shift);
3264     const int32 clamped_output =
3265         std::min(params.quantized_activation_max,
3266                  std::max(params.quantized_activation_min, unclamped_result));
3267     output_data[i] = static_cast<uint8>(clamped_output);
3268   }
3269 }
3270 
3271 // Scalar-broadcast mul, often used for the inner loop of a broadcast Mul.
3272 inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
3273                                const uint8 broadcast_value,
3274                                const uint8* input2_data, uint8* output_data) {
3275   const int16 input1_val = params.input1_offset + broadcast_value;
3276 
3277   int i = 0;
3278   TFLITE_DCHECK_GT(params.input1_offset, -256);
3279   TFLITE_DCHECK_LT(params.input1_offset, 256);
3280   TFLITE_DCHECK_GT(params.input2_offset, -256);
3281   TFLITE_DCHECK_LT(params.input2_offset, 256);
3282   TFLITE_DCHECK_GT(params.output_offset, -256);
3283   TFLITE_DCHECK_LT(params.output_offset, 256);
3284 #ifdef USE_NEON
3285   const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
3286   const auto output_offset_vector = vdupq_n_s16(params.output_offset);
3287   const auto output_activation_min_vector =
3288       vdup_n_u8(params.quantized_activation_min);
3289   const auto output_activation_max_vector =
3290       vdup_n_u8(params.quantized_activation_max);
3291   for (; i <= size - 8; i += 8) {
3292     // We load / store 8 at a time, multiplying as two sets of 4 int32s.
3293     const auto input2_val_original = vld1_u8(input2_data + i);
3294     const auto input2_val_s16 =
3295         vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
3296     const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
3297 
3298     const auto input2_val_low = vget_low_s16(input2_val);
3299     const auto input2_val_high = vget_high_s16(input2_val);
3300 
3301     auto p1 = vmull_n_s16(input2_val_low, input1_val);
3302     auto p2 = vmull_n_s16(input2_val_high, input1_val);
3303 
3304     p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
3305     p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
3306     using gemmlowp::RoundingDivideByPOT;
3307     p1 = RoundingDivideByPOT(p1, -params.output_shift);
3308     p2 = RoundingDivideByPOT(p2, -params.output_shift);
3309 
3310     const auto p1_narrowed = vmovn_s32(p1);
3311     const auto p2_narrowed = vmovn_s32(p2);
3312     const auto p =
3313         vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
3314     const auto clamped =
3315         vmax_u8(output_activation_min_vector,
3316                 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
3317     vst1_u8(output_data + i, clamped);
3318   }
3319 #endif  // USE_NEON
3320 
3321   for (; i < size; ++i) {
3322     const int32 input2_val = params.input2_offset + input2_data[i];
3323     const int32 unclamped_result =
3324         params.output_offset +
3325         MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
3326                                                        params.output_multiplier,
3327                                                        params.output_shift);
3328     const int32 clamped_output =
3329         std::min(params.quantized_activation_max,
3330                  std::max(params.quantized_activation_min, unclamped_result));
3331     output_data[i] = static_cast<uint8>(clamped_output);
3332   }
3333 }
3334 
3335 inline void Mul(const ArithmeticParams& params,
3336                 const RuntimeShape& input1_shape, const uint8* input1_data,
3337                 const RuntimeShape& input2_shape, const uint8* input2_data,
3338                 const RuntimeShape& output_shape, uint8* output_data) {
3339   TFLITE_DCHECK_LE(params.quantized_activation_min,
3340                    params.quantized_activation_max);
3341   gemmlowp::ScopedProfilingLabel label("Mul/8bit");
3342   const int flat_size =
3343       MatchingFlatSize(input1_shape, input2_shape, output_shape);
3344 
3345   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
3346 }
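// A minimal usage sketch for the 8-bit Mul above, assuming the usual TFLite
// quantization setup; the *_scale and *_zero_point names are placeholders for
// values that would come from the tensors:
//
//   ArithmeticParams op_params;
//   op_params.input1_offset = -input1_zero_point;
//   op_params.input2_offset = -input2_zero_point;
//   op_params.output_offset = output_zero_point;
//   const double real_multiplier =
//       input1_scale * input2_scale / output_scale;  // expected to be < 1
//   QuantizeMultiplierSmallerThanOneExp(real_multiplier,
//                                       &op_params.output_multiplier,
//                                       &op_params.output_shift);
//   op_params.quantized_activation_min = 0;
//   op_params.quantized_activation_max = 255;
//   Mul(op_params, input1_shape, input1_data, input2_shape, input2_data,
//       output_shape, output_data);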
3347 
3348 inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
3349                                  const RuntimeShape& unswitched_input1_shape,
3350                                  const uint8* unswitched_input1_data,
3351                                  const RuntimeShape& unswitched_input2_shape,
3352                                  const uint8* unswitched_input2_data,
3353                                  const RuntimeShape& output_shape,
3354                                  uint8* output_data) {
3355   gemmlowp::ScopedProfilingLabel label("BroadcastMulFivefold/8bit");
3356 
3357   ArithmeticParams switched_params = unswitched_params;
3358   switched_params.input1_offset = unswitched_params.input2_offset;
3359   switched_params.input2_offset = unswitched_params.input1_offset;
3360 
3361   const bool use_unswitched =
3362       unswitched_params.broadcast_category ==
3363       tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
3364 
3365   const ArithmeticParams& params =
3366       use_unswitched ? unswitched_params : switched_params;
3367   const uint8* input1_data =
3368       use_unswitched ? unswitched_input1_data : unswitched_input2_data;
3369   const uint8* input2_data =
3370       use_unswitched ? unswitched_input2_data : unswitched_input1_data;
3371 
3372   // Fivefold nested loops. The second input resets its position for each
3373   // iteration of the second loop. The first input resets its position at the
3374   // beginning of the fourth loop. The innermost loop is an elementwise Mul of
3375   // sections of the arrays.
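  // In other words, with broadcast_shape = {y0, y1, y2, y3, y4} the output has
  // y0*y1*y2*y3*y4 elements, the first input is effectively of shape
  // [y0, y1, y2, 1, y4] (its pointer advances by y4 once per i2 iteration), and
  // the second input is effectively of shape [y0, 1, y2, y3, y4] (it is
  // re-read for every i1 iteration).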
3376   uint8* output_data_ptr = output_data;
3377   const uint8* input1_data_ptr = input1_data;
3378   const uint8* input2_data_reset = input2_data;
3379   int y0 = params.broadcast_shape[0];
3380   int y1 = params.broadcast_shape[1];
3381   int y2 = params.broadcast_shape[2];
3382   int y3 = params.broadcast_shape[3];
3383   int y4 = params.broadcast_shape[4];
3384   if (y4 > 1) {
3385     for (int i0 = 0; i0 < y0; ++i0) {
3386       const uint8* input2_data_ptr = nullptr;
3387       for (int i1 = 0; i1 < y1; ++i1) {
3388         input2_data_ptr = input2_data_reset;
3389         for (int i2 = 0; i2 < y2; ++i2) {
3390           for (int i3 = 0; i3 < y3; ++i3) {
3391             MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
3392                            output_data_ptr);
3393             input2_data_ptr += y4;
3394             output_data_ptr += y4;
3395           }
3396           input1_data_ptr += y4;
3397         }
3398       }
3399       input2_data_reset = input2_data_ptr;
3400     }
3401   } else {
3402     for (int i0 = 0; i0 < y0; ++i0) {
3403       const uint8* input2_data_ptr = nullptr;
3404       for (int i1 = 0; i1 < y1; ++i1) {
3405         input2_data_ptr = input2_data_reset;
3406         for (int i2 = 0; i2 < y2; ++i2) {
3407           MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
3408                              output_data_ptr);
3409           input2_data_ptr += y3;
3410           output_data_ptr += y3;
3411           ++input1_data_ptr;
3412         }
3413       }
3414       input2_data_reset = input2_data_ptr;
3415     }
3416   }
3417 }
3418 
3419 // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
3420 // dimensionality if the runtime code does a single loop over one dimension
3421 // that handles broadcasting as the base case. The code generator would then
3422 // generate max(D1, D2) nested for loops.
3423 // TODO(benoitjacob): BroadcastDiv is intentionally duplicated from
3424 // reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
3425 // is no longer referenced in this file, move NdArrayDesc<T> from types.h to
3426 // reference_ops.h.
3427 template <typename T>
3428 void BroadcastDiv4DSlow(const ArithmeticParams& params,
3429                         const RuntimeShape& unextended_input1_shape,
3430                         const T* input1_data,
3431                         const RuntimeShape& unextended_input2_shape,
3432                         const T* input2_data,
3433                         const RuntimeShape& unextended_output_shape,
3434                         T* output_data) {
3435   gemmlowp::ScopedProfilingLabel label("BroadcastDiv4DSlow");
3436   T output_activation_min;
3437   T output_activation_max;
3438   GetActivationParams(params, &output_activation_min, &output_activation_max);
3439 
3440   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
3441   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
3442   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
3443   const RuntimeShape output_shape =
3444       RuntimeShape::ExtendedShape(4, unextended_output_shape);
3445 
3446   NdArrayDesc<4> desc1;
3447   NdArrayDesc<4> desc2;
3448   NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
3449                                       unextended_input2_shape, &desc1, &desc2);
3450 
3451   // In TensorFlow, the dimensions are canonically named (batch_number, row,
3452   // col, channel), with extents (batches, height, width, depth), with the
3453   // trailing dimension changing most rapidly (channels has the smallest stride,
3454   // typically 1 element).
3455   //
3456   // In generated C code, we store arrays with the dimensions reversed. The
3457   // first dimension has smallest stride.
3458   //
3459   // We name our variables by their TensorFlow convention, but generate C code
3460   // nesting loops such that the innermost loop has the smallest stride for the
3461   // best cache behavior.
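  // Concretely, for an extended NHWC shape (batches, height, width, depth),
  // Offset(output_shape, b, y, x, c) is
  // ((b * height + y) * width + x) * depth + c, so the innermost c loop below
  // walks contiguous memory.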
3462   for (int b = 0; b < output_shape.Dims(0); ++b) {
3463     for (int y = 0; y < output_shape.Dims(1); ++y) {
3464       for (int x = 0; x < output_shape.Dims(2); ++x) {
3465         for (int c = 0; c < output_shape.Dims(3); ++c) {
3466           output_data[Offset(output_shape, b, y, x, c)] =
3467               ActivationFunctionWithMinMax(
3468                   input1_data[SubscriptToIndex(desc1, b, y, x, c)] /
3469                       input2_data[SubscriptToIndex(desc2, b, y, x, c)],
3470                   output_activation_min, output_activation_max);
3471         }
3472       }
3473     }
3474   }
3475 }
3476 
3477 // TODO(aselle): This is not actually optimized yet.
3478 inline void SubNonBroadcast(const ArithmeticParams& params,
3479                             const RuntimeShape& input1_shape,
3480                             const float* input1_data,
3481                             const RuntimeShape& input2_shape,
3482                             const float* input2_data,
3483                             const RuntimeShape& output_shape,
3484                             float* output_data) {
3485   gemmlowp::ScopedProfilingLabel label("SubNonBroadcast");
3486   const int flat_size =
3487       MatchingFlatSize(input1_shape, input2_shape, output_shape);
3488   for (int i = 0; i < flat_size; ++i) {
3489     output_data[i] = ActivationFunctionWithMinMax(
3490         input1_data[i] - input2_data[i], params.float_activation_min,
3491         params.float_activation_max);
3492   }
3493 }
3494 
3495 inline void SubWithActivation(const ArithmeticParams& params,
3496                               const RuntimeShape& input1_shape,
3497                               const int32* input1_data,
3498                               const RuntimeShape& input2_shape,
3499                               const int32* input2_data,
3500                               const RuntimeShape& output_shape,
3501                               int32* output_data) {
3502   gemmlowp::ScopedProfilingLabel label("SubWithActivation/int32");
3503   const int flat_size =
3504       MatchingFlatSize(input1_shape, input2_shape, output_shape);
3505   for (int i = 0; i < flat_size; ++i) {
3506     output_data[i] = ActivationFunctionWithMinMax(
3507         input1_data[i] - input2_data[i], params.quantized_activation_min,
3508         params.quantized_activation_max);
3509   }
3510 }
3511 
3512 inline void SubWithActivation(const ArithmeticParams& params,
3513                               const RuntimeShape& input1_shape,
3514                               const float* input1_data,
3515                               const RuntimeShape& input2_shape,
3516                               const float* input2_data,
3517                               const RuntimeShape& output_shape,
3518                               float* output_data) {
3519   gemmlowp::ScopedProfilingLabel label("SubWithActivation/float");
3520   const int flat_size =
3521       MatchingFlatSize(input1_shape, input2_shape, output_shape);
3522   for (int i = 0; i < flat_size; ++i) {
3523     output_data[i] = ActivationFunctionWithMinMax(
3524         input1_data[i] - input2_data[i], params.float_activation_min,
3525         params.float_activation_max);
3526   }
3527 }
3528 
3529 template <typename T>
3530 void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
3531          const T* input1_data, const RuntimeShape& input2_shape,
3532          const T* input2_data, const RuntimeShape& output_shape,
3533          T* output_data) {
3534   gemmlowp::ScopedProfilingLabel label("Sub");
3535 
3536   auto input1_map = MapAsVector(input1_data, input1_shape);
3537   auto input2_map = MapAsVector(input2_data, input2_shape);
3538   auto output_map = MapAsVector(output_data, output_shape);
3539   if (input1_shape == input2_shape) {
3540     output_map.array() = input1_map.array() - input2_map.array();
3541   } else if (input1_shape.FlatSize() == 1) {
3542     auto scalar = input1_data[0];
3543     output_map.array() = scalar - input2_map.array();
3544   } else if (input2_shape.FlatSize() == 1) {
3545     auto scalar = input2_data[0];
3546     output_map.array() = input1_map.array() - scalar;
3547   } else {
3548     BroadcastSub4DSlow(params, input1_shape, input1_data, input2_shape,
3549                        input2_data, output_shape, output_data);
3550   }
3551 }
3552 
3553 inline void LstmCell(
3554     const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
3555     const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
3556     const float* prev_activ_data, const RuntimeShape& weights_shape,
3557     const float* weights_data, const RuntimeShape& unextended_bias_shape,
3558     const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
3559     const float* prev_state_data,
3560     const RuntimeShape& unextended_output_state_shape, float* output_state_data,
3561     const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
3562     const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
3563     const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
3564   gemmlowp::ScopedProfilingLabel label("LstmCell");
3565   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
3566   TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
3567   TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
3568   TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
3569   TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
3570   TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
3571   TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
3572   TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
3573   const RuntimeShape input_shape =
3574       RuntimeShape::ExtendedShape(4, unextended_input_shape);
3575   const RuntimeShape prev_activ_shape =
3576       RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
3577   const RuntimeShape bias_shape =
3578       RuntimeShape::ExtendedShape(4, unextended_bias_shape);
3579   const RuntimeShape prev_state_shape =
3580       RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
3581   const RuntimeShape output_state_shape =
3582       RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
3583   const RuntimeShape output_activ_shape =
3584       RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
3585   const RuntimeShape concat_temp_shape =
3586       RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
3587   const RuntimeShape activ_temp_shape =
3588       RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
3589   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
3590 
3591   const int weights_dim_count = weights_shape.DimensionsCount();
3592   MatchingDim(  // batches
3593       input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
3594       output_state_shape, 0, output_activ_shape, 0);
3595   MatchingDim(  // height
3596       input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
3597       output_state_shape, 1, output_activ_shape, 1);
3598   MatchingDim(  // width
3599       input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
3600       output_state_shape, 2, output_activ_shape, 2);
3601   const int input_depth = input_shape.Dims(3);
3602   const int prev_activ_depth = prev_activ_shape.Dims(3);
3603   const int total_input_depth = prev_activ_depth + input_depth;
3604   TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
3605                    total_input_depth);
3606   TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
3607   const int intern_activ_depth =
3608       MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
3609   TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
3610                    intern_activ_depth * total_input_depth);
3611   TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
3612   const int output_depth =
3613       MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
3614                   3, output_activ_shape, 3);
3615   TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
3616 
3617   // Concatenate prev_activ and input data together
3618   std::vector<float const*> concat_input_arrays_data;
3619   std::vector<RuntimeShape const*> concat_input_arrays_shapes;
3620   concat_input_arrays_data.push_back(input_data);
3621   concat_input_arrays_data.push_back(prev_activ_data);
3622   concat_input_arrays_shapes.push_back(&input_shape);
3623   concat_input_arrays_shapes.push_back(&prev_activ_shape);
3624   tflite::ConcatenationParams concat_params;
3625   concat_params.axis = 3;
3626   concat_params.inputs_count = concat_input_arrays_data.size();
3627   Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
3628                 &(concat_input_arrays_data[0]), concat_temp_shape,
3629                 concat_temp_data);
3630 
3631   // Fully connected
3632   tflite::FullyConnectedParams fc_params;
3633   fc_params.float_activation_min = std::numeric_limits<float>::lowest();
3634   fc_params.float_activation_max = std::numeric_limits<float>::max();
3635   FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
3636                  weights_data, bias_shape, bias_data, activ_temp_shape,
3637                  activ_temp_data);
3638 
3639   // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
3640   // operations.
3641   ArrayMap<float> activ_temp_map =
3642       MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape);
3643   auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
3644                                             activ_temp_map.cols());
3645   auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
3646                                            activ_temp_map.cols());
3647   auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth,
3648                                              activ_temp_map.cols());
3649   auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
3650                                              activ_temp_map.cols());
3651   ArrayMap<const float> prev_state_map =
3652       MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape);
3653   ArrayMap<float> output_state_map =
3654       MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape);
3655   ArrayMap<float> output_activ_map =
3656       MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape);
3657 
3658   // Combined memory state and final output calculation
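  // In equation form (matching the code below):
  //   new_state    = logistic(input_gate) * tanh(new_input)
  //                  + logistic(forget_gate) * prev_state
  //   output_activ = logistic(output_gate) * tanh(new_state)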
3659   gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput");
3660   output_state_map =
3661       input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
3662           new_input_sm.tanh() +
3663       forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
3664           prev_state_map;
3665   output_activ_map =
3666       output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
3667       output_state_map.tanh();
3668 }
3669 
3670 // Quantized LSTM cell. Currently just a copy of the reference impl in
3671 // reference_ops.h. See the big function comment there; it is not replicated
3672 // here.
3673 template <int StateIntegerBits>
3674 inline void LstmCell(
3675     const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
3676     const uint8* input_data_uint8,
3677     const RuntimeShape& unextended_prev_activ_shape,
3678     const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
3679     const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
3680     const int32* bias_data_int32,
3681     const RuntimeShape& unextended_prev_state_shape,
3682     const int16* prev_state_data_int16,
3683     const RuntimeShape& unextended_output_state_shape,
3684     int16* output_state_data_int16,
3685     const RuntimeShape& unextended_output_activ_shape,
3686     uint8* output_activ_data_uint8,
3687     const RuntimeShape& unextended_concat_temp_shape,
3688     uint8* concat_temp_data_uint8,
3689     const RuntimeShape& unextended_activ_temp_shape,
3690     int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
3691   gemmlowp::ScopedProfilingLabel label(
3692       "LstmCell/quantized (8bit external, 16bit internal)");
3693   int32 weights_zero_point = params.weights_zero_point;
3694   int32 accum_multiplier = params.accum_multiplier;
3695   int accum_shift = params.accum_shift;
3696   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
3697   TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
3698   TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
3699   TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
3700   TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
3701   TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
3702   TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
3703   TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
3704   const RuntimeShape input_shape =
3705       RuntimeShape::ExtendedShape(4, unextended_input_shape);
3706   const RuntimeShape prev_activ_shape =
3707       RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
3708   const RuntimeShape bias_shape =
3709       RuntimeShape::ExtendedShape(4, unextended_bias_shape);
3710   const RuntimeShape prev_state_shape =
3711       RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
3712   const RuntimeShape output_state_shape =
3713       RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
3714   const RuntimeShape output_activ_shape =
3715       RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
3716   const RuntimeShape concat_temp_shape =
3717       RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
3718   const RuntimeShape activ_temp_shape =
3719       RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
3720   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
3721 
3722   // Gather dimension information, and perform consistency checks.
3723   const int weights_dim_count = weights_shape.DimensionsCount();
3724   const int outer_size = MatchingFlatSizeSkipDim(
3725       input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
3726       output_activ_shape);
3727   const int input_depth = input_shape.Dims(3);
3728   const int prev_activ_depth = prev_activ_shape.Dims(3);
3729   const int total_input_depth = prev_activ_depth + input_depth;
3730   TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
3731                    total_input_depth);
3732   const int intern_activ_depth =
3733       MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
3734   TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
3735                    intern_activ_depth * total_input_depth);
3736   TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
3737   TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
3738   const int output_depth =
3739       MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
3740                   3, output_activ_shape, 3);
3741   TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
3742   const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
3743   const int fc_output_depth =
3744       MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
3745   const int fc_accum_depth = total_input_depth;
3746   TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
3747 
3748   // Depth-concatenate prev_activ and input data together.
3749   uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
3750                                               prev_activ_data_uint8};
3751   const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
3752                                                        &prev_activ_shape};
3753   tflite::ConcatenationParams concat_params;
3754   concat_params.axis = 3;
3755   concat_params.inputs_count = 2;
3756   Concatenation(concat_params, concat_input_arrays_shapes,
3757                 concat_input_arrays_data, concat_temp_shape,
3758                 concat_temp_data_uint8);
3759 
3760   // Implementation of the fully connected node inside the LSTM cell.
3761   // The operands are 8-bit integers, the accumulators are internally 32-bit
3762   // integers, and the output is 16-bit fixed-point with 3 integer bits so
3763   // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
3764   // is explained in the function comment above.
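  // For intuition: with 3 integer bits, a raw int16 value v in activ_temp
  // represents the real number v / 2^12 (so raw 4096 is 1.0 and raw 32767 is
  // just under 8.0).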
3765   bool gemm_already_performed = false;
3766 #ifdef GEMMLOWP_NEON
3767   if (fc_batches == 1 && !(fc_output_depth % 4) && !(fc_accum_depth % 8)) {
3768     GEMVForLstmCell(concat_temp_shape, concat_temp_data_uint8, weights_shape,
3769                     weights_data_uint8, weights_zero_point, bias_shape,
3770                     bias_data_int32, accum_multiplier, accum_shift,
3771                     activ_temp_shape, activ_temp_data_int16);
3772     gemm_already_performed = true;
3773   }
3774 #endif
3775   if (!gemm_already_performed) {
3776     gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor>
3777         weights_matrix(weights_data_uint8, fc_output_depth, fc_accum_depth);
3778     gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
3779         concat_temp_data_uint8, fc_accum_depth, fc_batches);
3780     gemmlowp::MatrixMap<int16, gemmlowp::MapOrder::ColMajor> output_matrix(
3781         activ_temp_data_int16, fc_output_depth, fc_batches);
3782     typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
3783         ColVectorMap;
3784     ColVectorMap bias_vector(bias_data_int32, fc_output_depth);
3785     gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
3786     bias_addition_stage.bias_vector = bias_vector;
3787     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
3788     scale_stage.result_offset_after_shift = 0;
3789     scale_stage.result_fixedpoint_multiplier = accum_multiplier;
3790     scale_stage.result_exponent = accum_shift;
3791     gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
3792     auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
3793                                            saturating_cast_int16_stage);
3794     gemmlowp::GemmWithOutputPipeline<
3795         uint8, int16, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
3796         gemm_context, weights_matrix, input_matrix, &output_matrix,
3797         -weights_zero_point, -128, output_pipeline);
3798   }
3799 
3800   // Rest of the LSTM cell: tanh and logistic math functions, and some adds
3801   // and muls, all done in 16-bit fixed-point.
3802   const int16* input_gate_input_ptr = activ_temp_data_int16;
3803   const int16* input_modulation_gate_input_ptr =
3804       activ_temp_data_int16 + output_depth;
3805   const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth;
3806   const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth;
3807   const int16* prev_state_ptr = prev_state_data_int16;
3808   int16* output_state_data_ptr = output_state_data_int16;
3809   uint8* output_activ_data_ptr = output_activ_data_uint8;
3810 
3811   for (int b = 0; b < outer_size; ++b) {
3812     int c = 0;
3813 #ifdef GEMMLOWP_NEON
3814     for (; c <= output_depth - 8; c += 8) {
3815       // Define the fixed-point data types that we will use here. All use
3816       // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3817       // They only differ by the number of integral vs. fractional bits,
3818       // determining the range of values that they can represent.
3819       //
3820       // F0 uses 0 integer bits, range [-1, 1].
3821       // This is the return type of math functions such as tanh, logistic,
3822       // whose range is in [-1, 1].
3823       using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
3824       // F3 uses 3 integer bits, range [-8, 8].
3825       // This is the range of the previous fully-connected node's output,
3826       // which is our input here.
3827       using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
3828       // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3829       // 2^StateIntegerBits]. It's used to represent the internal state, whose
3830       // number of integer bits is currently dictated by the model. See comment
3831       // on the StateIntegerBits template parameter above.
3832       using FS = gemmlowp::FixedPoint<int16x8_t, StateIntegerBits>;
3833       // Implementation of input gate, using fixed-point logistic function.
3834       F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr));
3835       input_gate_input_ptr += 8;
3836       F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3837       // Implementation of input modulation gate, using fixed-point tanh
3838       // function.
3839       F3 input_modulation_gate_input =
3840           F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr));
3841       input_modulation_gate_input_ptr += 8;
3842       F0 input_modulation_gate_output =
3843           gemmlowp::tanh(input_modulation_gate_input);
3844       // Implementation of forget gate, using fixed-point logistic function.
3845       F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr));
3846       forget_gate_input_ptr += 8;
3847       F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3848       // Implementation of output gate, using fixed-point logistic function.
3849       F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr));
3850       output_gate_input_ptr += 8;
3851       F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3852       // Implementation of internal multiplication nodes, still in fixed-point.
3853       F0 input_times_input_modulation =
3854           input_gate_output * input_modulation_gate_output;
3855       FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr));
3856       prev_state_ptr += 8;
3857       FS prev_state_times_forget_state = forget_gate_output * prev_state;
3858       // Implementation of internal addition node, saturating.
3859       FS new_state = gemmlowp::SaturatingAdd(
3860           gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3861           prev_state_times_forget_state);
3862       // Implementation of last internal Tanh node, still in fixed-point.
3863       // Since a Tanh fixed-point implementation is specialized for a given
3864       // number of integer bits, and each specialization can have a substantial
3865       // code size, and we already used above a Tanh on an input with 3 integer
3866       // bits, and per the table in the above function comment there is no
3867       // significant accuracy to be lost by clamping to [-8, +8] for a
3868       // 3-integer-bits representation, let us just do that. This helps people
3869       // porting this to targets where code footprint must be minimized.
3870       F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3871       F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3872       // Store the new internal state back to memory, as 16-bit integers.
3873       // Note: here we store the original value with StateIntegerBits, not
3874       // the rescaled 3-integer-bits value fed to tanh.
3875       vst1q_s16(output_state_data_ptr, new_state.raw());
3876       output_state_data_ptr += 8;
3877       // Down-scale the output activations to 8-bit integers, saturating,
3878       // and store back to memory.
3879       int16x8_t rescaled_output_activ =
3880           gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3881       int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ);
3882       uint8x8_t uint8_output_activ =
3883           vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ));
3884       vst1_u8(output_activ_data_ptr, uint8_output_activ);
3885       output_activ_data_ptr += 8;
3886     }
3887 #endif
3888     for (; c < output_depth; ++c) {
3889       // Define the fixed-point data types that we will use here. All use
3890       // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3891       // They only differ by the number of integral vs. fractional bits,
3892       // determining the range of values that they can represent.
3893       //
3894       // F0 uses 0 integer bits, range [-1, 1].
3895       // This is the return type of math functions such as tanh, logistic,
3896       // whose range is in [-1, 1].
3897       using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
3898       // F3 uses 3 integer bits, range [-8, 8].
3899       // This is the range of the previous fully-connected node's output,
3900       // which is our input here.
3901       using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
3902       // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3903       // 2^StateIntegerBits]. It's used to represent the internal state, whose
3904       // number of integer bits is currently dictated by the model. See comment
3905       // on the StateIntegerBits template parameter above.
3906       using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
3907       // Implementation of input gate, using fixed-point logistic function.
3908       F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++);
3909       F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3910       // Implementation of input modulation gate, using fixed-point tanh
3911       // function.
3912       F3 input_modulation_gate_input =
3913           F3::FromRaw(*input_modulation_gate_input_ptr++);
3914       F0 input_modulation_gate_output =
3915           gemmlowp::tanh(input_modulation_gate_input);
3916       // Implementation of forget gate, using fixed-point logistic function.
3917       F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++);
3918       F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3919       // Implementation of output gate, using fixed-point logistic function.
3920       F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++);
3921       F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3922       // Implementation of internal multiplication nodes, still in fixed-point.
3923       F0 input_times_input_modulation =
3924           input_gate_output * input_modulation_gate_output;
3925       FS prev_state = FS::FromRaw(*prev_state_ptr++);
3926       FS prev_state_times_forget_state = forget_gate_output * prev_state;
3927       // Implementation of internal addition node, saturating.
3928       FS new_state = gemmlowp::SaturatingAdd(
3929           gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3930           prev_state_times_forget_state);
3931       // Implementation of last internal Tanh node, still in fixed-point.
3932       // Since a Tanh fixed-point implementation is specialized for a given
3933       // number of integer bits, and each specialization can have a substantial
3934       // code size, and we already used above a Tanh on an input with 3 integer
3935       // bits, and per the table in the above function comment there is no
3936       // significant accuracy to be lost by clamping to [-8, +8] for a
3937       // 3-integer-bits representation, let us just do that. This helps people
3938       // porting this to targets where code footprint must be minimized.
3939       F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3940       F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3941       // Store the new internal state back to memory, as 16-bit integers.
3942       // Note: here we store the original value with StateIntegerBits, not
3943       // the rescaled 3-integer-bits value fed to tanh.
3944       *output_state_data_ptr++ = new_state.raw();
3945       // Down-scale the output activations to 8-bit integers, saturating,
3946       // and store back to memory.
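      // The stored uint8 value q thus corresponds to a real activation of
      // roughly (q - 128) / 128: the Q0.15 raw value is divided by 2^8
      // (about real * 128) and then offset by 128.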
3947       int16 rescaled_output_activ =
3948           gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3949       int16 clamped_output_activ =
3950           std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
3951       *output_activ_data_ptr++ = 128 + clamped_output_activ;
3952     }
3953     input_gate_input_ptr += 3 * output_depth;
3954     input_modulation_gate_input_ptr += 3 * output_depth;
3955     forget_gate_input_ptr += 3 * output_depth;
3956     output_gate_input_ptr += 3 * output_depth;
3957   }
3958 }
3959 
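// Computes the column index of element (b, h, w) in the matrix views created
// by MapAsMatrixWithLastDimAsRows in the pooling ops below (each column holds
// one depth slice).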
3960 inline int NodeOffset(int b, int h, int w, int height, int width) {
3961   return (b * height + h) * width + w;
3962 }
3963 
3964 inline void AveragePool(const PoolParams& params,
3965                         const RuntimeShape& input_shape,
3966                         const float* input_data,
3967                         const RuntimeShape& output_shape, float* output_data) {
3968   gemmlowp::ScopedProfilingLabel label("AveragePool");
3969   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3970   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3971   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3972   const int input_height = input_shape.Dims(1);
3973   const int input_width = input_shape.Dims(2);
3974   const int output_height = output_shape.Dims(1);
3975   const int output_width = output_shape.Dims(2);
3976   const int stride_height = params.stride_height;
3977   const int stride_width = params.stride_width;
3978 
3979   // TODO(benoitjacob) make this a proper reference impl without Eigen!
3980   const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3981   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3982   // TODO(benoitjacob) get rid of the dynamic memory allocation here!
3983   Eigen::VectorXf out_count(out_mat.cols());
3984   out_count.setZero();
3985   // Prefill the output to 0.
3986   out_mat.setZero();
3987   for (int b = 0; b < batches; ++b) {
3988     for (int h = 0; h < input_height; ++h) {
3989       for (int w = 0; w < input_width; ++w) {
3990         // (h_start, h_end) * (w_start, w_end) is the range that the input
3991         // vector projects to.
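        // Derivation: input row h contributes to output row ph exactly when
        // ph * stride_height <= hpad < ph * stride_height + filter_height;
        // solving for ph yields the half-open range [h_start, h_end) computed
        // below, and similarly for the width axis.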
3992         int hpad = h + params.padding_values.height;
3993         int wpad = w + params.padding_values.width;
3994         int h_start = (hpad < params.filter_height)
3995                           ? 0
3996                           : (hpad - params.filter_height) / stride_height + 1;
3997         int h_end = std::min(hpad / stride_height + 1, output_height);
3998         int w_start = (wpad < params.filter_width)
3999                           ? 0
4000                           : (wpad - params.filter_width) / stride_width + 1;
4001         int w_end = std::min(wpad / stride_width + 1, output_width);
4002         // compute elementwise sum
4003         for (int ph = h_start; ph < h_end; ++ph) {
4004           for (int pw = w_start; pw < w_end; ++pw) {
4005             int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
4006             out_mat.col(out_offset) +=
4007                 in_mat.col(NodeOffset(b, h, w, input_height, input_width));
4008             out_count(out_offset)++;
4009           }
4010         }
4011       }
4012     }
4013   }
4014   // Divide the output by the actual number of elements being averaged over
4015   TFLITE_DCHECK_GT(out_count.minCoeff(), 0);
4016   out_mat.array().rowwise() /= out_count.transpose().array();
4017 
4018   const int flat_size = output_shape.FlatSize();
4019   for (int i = 0; i < flat_size; ++i) {
4020     output_data[i] = ActivationFunctionWithMinMax(output_data[i],
4021                                                   params.float_activation_min,
4022                                                   params.float_activation_max);
4023   }
4024 }
4025 
4026 inline void AveragePool(const PoolParams& params,
4027                         const RuntimeShape& input_shape,
4028                         const uint8* input_data,
4029                         const RuntimeShape& output_shape, uint8* output_data) {
4030   gemmlowp::ScopedProfilingLabel label("AveragePool/8bit");
4031 
4032   // Here, and in other pooling ops, in order to maintain locality of reference,
4033   // to minimize some recalculations, and to load into NEON vector registers, we
4034   // use an inner loop down the depth. Since depths can be large and hence we
4035   // would need arbitrarily large temporary storage, we divide the work up into
4036   // depth tranches just within the batch loop.
4037   static constexpr int kPoolingAccTrancheSize = 256;
4038 
4039   TFLITE_DCHECK_LE(params.quantized_activation_min,
4040                    params.quantized_activation_max);
4041   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
4042   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
4043   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
4044   const int depth = MatchingDim(input_shape, 3, output_shape, 3);
4045   const int input_height = input_shape.Dims(1);
4046   const int input_width = input_shape.Dims(2);
4047   const int output_height = output_shape.Dims(1);
4048   const int output_width = output_shape.Dims(2);
4049   const int stride_height = params.stride_height;
4050   const int stride_width = params.stride_width;
4051 
4052   uint16 acc[kPoolingAccTrancheSize];
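  // Note: the uint16 accumulators assume filter_count * 255 fits in 16 bits,
  // i.e. at most 257 taps per pooling window; larger windows would overflow.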
4053   for (int batch = 0; batch < batches; ++batch) {
4054     // We proceed through the depth in tranches (see comment above). The
4055     // depth_base is the depth at the beginning of the tranche. The
4056     // tranche_depth is the depth dimension of the tranche.
4057     for (int depth_base = 0; depth_base < depth;
4058          depth_base += kPoolingAccTrancheSize) {
4059       const int tranche_depth =
4060           std::min(depth - depth_base, kPoolingAccTrancheSize);
4061       for (int out_y = 0; out_y < output_height; ++out_y) {
4062         for (int out_x = 0; out_x < output_width; ++out_x) {
4063           const int in_x_origin =
4064               (out_x * stride_width) - params.padding_values.width;
4065           const int in_y_origin =
4066               (out_y * stride_height) - params.padding_values.height;
4067           const int filter_x_start = std::max(0, -in_x_origin);
4068           const int filter_x_end =
4069               std::min(params.filter_width, input_width - in_x_origin);
4070           const int filter_y_start = std::max(0, -in_y_origin);
4071           const int filter_y_end =
4072               std::min(params.filter_height, input_height - in_y_origin);
4073           const int filter_count =
4074               (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
4075           memset(acc, 0, tranche_depth * sizeof(acc[0]));
4076           const uint8* input_ptr =
4077               input_data + depth_base +
4078               depth * (in_x_origin +
4079                        input_width * (in_y_origin + input_height * batch));
4080           for (int fy = filter_y_start; fy < filter_y_end; fy++) {
4081             const uint8* input_row_ptr =
4082                 input_ptr + depth * (fy * input_width + filter_x_start);
4083             for (int fx = filter_x_start; fx < filter_x_end; fx++) {
4084               const uint8* input_channel_ptr = input_row_ptr;
4085               int channel = 0;
4086 #ifdef USE_NEON
4087               for (; channel <= tranche_depth - 16; channel += 16) {
4088                 uint16x8_t acc_reg[2];
4089                 for (int i = 0; i < 2; i++) {
4090                   acc_reg[i] = vld1q_u16(acc + channel + 8 * i);
4091                 }
4092                 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
4093                 input_channel_ptr += 16;
4094                 acc_reg[0] = vaddw_u8(acc_reg[0], vget_low_u8(input_reg));
4095                 acc_reg[1] = vaddw_u8(acc_reg[1], vget_high_u8(input_reg));
4096                 for (int i = 0; i < 2; i++) {
4097                   vst1q_u16(acc + channel + 8 * i, acc_reg[i]);
4098                 }
4099               }
4100               for (; channel <= tranche_depth - 8; channel += 8) {
4101                 uint16x8_t acc_reg = vld1q_u16(acc + channel);
4102                 uint8x8_t input_reg = vld1_u8(input_channel_ptr);
4103                 input_channel_ptr += 8;
4104                 acc_reg = vaddw_u8(acc_reg, input_reg);
4105                 vst1q_u16(acc + channel, acc_reg);
4106               }
4107 #endif
4108               for (; channel < tranche_depth; ++channel) {
4109                 acc[channel] += *input_channel_ptr++;
4110               }
4111               input_row_ptr += depth;
4112             }
4113           }
4114           uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
4115                                                    out_x, depth_base);
4116           int channel = 0;
4117 #ifdef USE_NEON
4118 #define AVGPOOL_DIVIDING_BY(FILTER_COUNT)                               \
4119   if (filter_count == FILTER_COUNT) {                                   \
4120     for (; channel <= tranche_depth - 8; channel += 8) {                \
4121       uint16 buf[8];                                                    \
4122       for (int i = 0; i < 8; i++) {                                     \
4123         buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT;  \
4124       }                                                                 \
4125       uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));                      \
4126       buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); \
4127       buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); \
4128       vst1_u8(output_ptr + channel, buf8);                              \
4129     }                                                                   \
4130   }
4131           AVGPOOL_DIVIDING_BY(9)
4132           AVGPOOL_DIVIDING_BY(15)
4133 #undef AVGPOOL_DIVIDING_BY
4134           for (; channel <= tranche_depth - 8; channel += 8) {
4135             uint16 buf[8];
4136             for (int i = 0; i < 8; i++) {
4137               buf[i] = (acc[channel + i] + filter_count / 2) / filter_count;
4138             }
4139             uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));
4140             buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max));
4141             buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min));
4142             vst1_u8(output_ptr + channel, buf8);
4143           }
4144 #endif
4145           for (; channel < tranche_depth; ++channel) {
4146             uint16 a = (acc[channel] + filter_count / 2) / filter_count;
4147             a = std::max<uint16>(a, params.quantized_activation_min);
4148             a = std::min<uint16>(a, params.quantized_activation_max);
4149             output_ptr[channel] = static_cast<uint8>(a);
4150           }
4151         }
4152       }
4153     }
4154   }
4155 }
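
// Note on the division in the tails above: the quantized average is computed
// as (acc + filter_count / 2) / filter_count, i.e. integer division rounded to
// the nearest value (ties round up). For example, an accumulated window sum of
// 10 over a full 3x3 window gives (10 + 4) / 9 = 1, which is the true average
// 10 / 9 = 1.11 rounded to nearest.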
4156 
4157 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
4158                     const float* input_data, const RuntimeShape& output_shape,
4159                     float* output_data) {
4160   gemmlowp::ScopedProfilingLabel label("MaxPool");
4161   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
4162   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
4163   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
4164   const int input_height = input_shape.Dims(1);
4165   const int input_width = input_shape.Dims(2);
4166   const int output_height = output_shape.Dims(1);
4167   const int output_width = output_shape.Dims(2);
4168   const int stride_height = params.stride_height;
4169   const int stride_width = params.stride_width;
4170 
4171   const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
4172   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
4173   // Prefill the output to the minimum representable float value.
4174   out_mat.setConstant(std::numeric_limits<float>::lowest());
4175   for (int b = 0; b < batches; ++b) {
4176     for (int h = 0; h < input_height; ++h) {
4177       for (int w = 0; w < input_width; ++w) {
4178         // (h_start, h_end) * (w_start, w_end) is the range that the input
4179         // vector projects to.
4180         int hpad = h + params.padding_values.height;
4181         int wpad = w + params.padding_values.width;
4182         int h_start = (hpad < params.filter_height)
4183                           ? 0
4184                           : (hpad - params.filter_height) / stride_height + 1;
4185         int h_end = std::min(hpad / stride_height + 1, output_height);
4186         int w_start = (wpad < params.filter_width)
4187                           ? 0
4188                           : (wpad - params.filter_width) / stride_width + 1;
4189         int w_end = std::min(wpad / stride_width + 1, output_width);
4190         // compute elementwise max
4191         for (int ph = h_start; ph < h_end; ++ph) {
4192           for (int pw = w_start; pw < w_end; ++pw) {
4193             int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
4194             out_mat.col(out_offset) =
4195                 out_mat.col(out_offset)
4196                     .cwiseMax(in_mat.col(
4197                         NodeOffset(b, h, w, input_height, input_width)));
4198           }
4199         }
4200       }
4201     }
4202   }
4203   const int flat_size = output_shape.FlatSize();
4204   for (int i = 0; i < flat_size; ++i) {
4205     output_data[i] = ActivationFunctionWithMinMax(output_data[i],
4206                                                   params.float_activation_min,
4207                                                   params.float_activation_max);
4208   }
4209 }
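
// The loops above run in "forward mode": each input element is pushed into
// every output cell whose pooling window covers it. For example, with
// filter_height = 3, stride_height = 2 and padding 1, input row h = 4 gives
// hpad = 5, h_start = (5 - 3) / 2 + 1 = 2 and h_end = min(5 / 2 + 1,
// output_height) = 3 (assuming output_height >= 3), so the value contributes
// only to output row 2, whose window spans input rows 3..5.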
4210 
4211 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
4212                     const uint8* input_data, const RuntimeShape& output_shape,
4213                     uint8* output_data) {
4214   gemmlowp::ScopedProfilingLabel label("MaxPool/8bit");
4215 
4216   // Here, and in other pooling ops, in order to maintain locality of reference,
4217   // to minimize some recalculations, and to load into NEON vector registers, we
4218   // use an inner loop down the depth. Since depths can be large and hence we
4219   // would need arbitrarily large temporary storage, we divide the work up into
4220   // depth tranches just within the batch loop.
4221   static constexpr int kPoolingAccTrancheSize = 256;
4222 
4223   TFLITE_DCHECK_LE(params.quantized_activation_min,
4224                    params.quantized_activation_max);
4225   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
4226   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
4227   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
4228   const int depth = MatchingDim(input_shape, 3, output_shape, 3);
4229   const int input_height = input_shape.Dims(1);
4230   const int input_width = input_shape.Dims(2);
4231   const int output_height = output_shape.Dims(1);
4232   const int output_width = output_shape.Dims(2);
4233   const int stride_height = params.stride_height;
4234   const int stride_width = params.stride_width;
4235 
4236   uint8 acc[kPoolingAccTrancheSize];
4237   for (int batch = 0; batch < batches; ++batch) {
4238     // We proceed through the depth in tranches (see comment above). The
4239     // depth_base is the depth at the beginning of the tranche. The
4240     // tranche_depth is the depth dimension of the tranche.
4241     for (int depth_base = 0; depth_base < depth;
4242          depth_base += kPoolingAccTrancheSize) {
4243       const int tranche_depth =
4244           std::min(depth - depth_base, kPoolingAccTrancheSize);
4245       for (int out_y = 0; out_y < output_height; ++out_y) {
4246         for (int out_x = 0; out_x < output_width; ++out_x) {
4247           const int in_x_origin =
4248               (out_x * stride_width) - params.padding_values.width;
4249           const int in_y_origin =
4250               (out_y * stride_height) - params.padding_values.height;
4251           const int filter_x_start = std::max(0, -in_x_origin);
4252           const int filter_x_end =
4253               std::min(params.filter_width, input_width - in_x_origin);
4254           const int filter_y_start = std::max(0, -in_y_origin);
4255           const int filter_y_end =
4256               std::min(params.filter_height, input_height - in_y_origin);
4257           memset(acc, 0, tranche_depth * sizeof(acc[0]));
4258           const uint8* input_ptr =
4259               input_data + depth_base +
4260               depth * (in_x_origin +
4261                        input_width * (in_y_origin + input_height * batch));
4262           for (int fy = filter_y_start; fy < filter_y_end; fy++) {
4263             const uint8* input_row_ptr =
4264                 input_ptr + depth * (fy * input_width + filter_x_start);
4265             for (int fx = filter_x_start; fx < filter_x_end; fx++) {
4266               const uint8* input_channel_ptr = input_row_ptr;
4267               int channel = 0;
4268 #ifdef USE_NEON
4269               for (; channel <= tranche_depth - 16; channel += 16) {
4270                 uint8x16_t acc_reg = vld1q_u8(acc + channel);
4271                 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
4272                 input_channel_ptr += 16;
4273                 acc_reg = vmaxq_u8(acc_reg, input_reg);
4274                 vst1q_u8(acc + channel, acc_reg);
4275               }
4276 
4277               for (; channel <= tranche_depth - 8; channel += 8) {
4278                 uint8x8_t acc_reg = vld1_u8(acc + channel);
4279                 uint8x8_t input_reg = vld1_u8(input_channel_ptr);
4280                 input_channel_ptr += 8;
4281                 acc_reg = vmax_u8(acc_reg, input_reg);
4282                 vst1_u8(acc + channel, acc_reg);
4283               }
4284 #endif
4285               for (; channel < tranche_depth; ++channel) {
4286                 acc[channel] = std::max(acc[channel], *input_channel_ptr++);
4287               }
4288               input_row_ptr += depth;
4289             }
4290           }
4291           uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
4292                                                    out_x, depth_base);
4293           int channel = 0;
4294 #ifdef USE_NEON
4295           for (; channel <= tranche_depth - 16; channel += 16) {
4296             uint8x16_t a = vld1q_u8(acc + channel);
4297             a = vminq_u8(a, vdupq_n_u8(params.quantized_activation_max));
4298             a = vmaxq_u8(a, vdupq_n_u8(params.quantized_activation_min));
4299             vst1q_u8(output_ptr + channel, a);
4300           }
4301           for (; channel <= tranche_depth - 8; channel += 8) {
4302             uint8x8_t a = vld1_u8(acc + channel);
4303             a = vmin_u8(a, vdup_n_u8(params.quantized_activation_max));
4304             a = vmax_u8(a, vdup_n_u8(params.quantized_activation_min));
4305             vst1_u8(output_ptr + channel, a);
4306           }
4307 #endif
4308           for (; channel < tranche_depth; ++channel) {
4309             uint8 a = acc[channel];
4310             a = std::max<uint8>(a, params.quantized_activation_min);
4311             a = std::min<uint8>(a, params.quantized_activation_max);
4312             output_ptr[channel] = static_cast<uint8>(a);
4313           }
4314         }
4315       }
4316     }
4317   }
4318 }
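
// Example of the tranche split used above: with depth = 600 and
// kPoolingAccTrancheSize = 256, the depth loop runs three tranches per batch
// with tranche_depth = 256, 256 and 88; all output positions are visited
// within each tranche, so the on-stack accumulator never needs more than
// kPoolingAccTrancheSize entries.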
4319 
4320 inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
4321                    const float* input_data, const RuntimeShape& output_shape,
4322                    float* output_data) {
4323   gemmlowp::ScopedProfilingLabel label("L2Pool");
4324   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
4325   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
4326   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
4327   const int input_height = input_shape.Dims(1);
4328   const int input_width = input_shape.Dims(2);
4329   const int output_height = output_shape.Dims(1);
4330   const int output_width = output_shape.Dims(2);
4331   const int stride_height = params.stride_height;
4332   const int stride_width = params.stride_width;
4333   // Actually carry out L2 Pool. Code is written in forward mode: we go through
4334   // the input values once and write to all the pooled regions that each maps to.
4335   const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
4336   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
4337   Eigen::VectorXf in_square(in_mat.rows());
4338   Eigen::VectorXf out_count(out_mat.cols());
4339   out_count.setZero();
4340   // Prefill the output to 0.
4341   out_mat.setZero();
4342   for (int b = 0; b < batches; ++b) {
4343     for (int h = 0; h < input_height; ++h) {
4344       for (int w = 0; w < input_width; ++w) {
4345         // (h_start, h_end) * (w_start, w_end) is the range that the input
4346         // vector projects to.
4347         const int hpad = h + params.padding_values.height;
4348         const int wpad = w + params.padding_values.width;
4349         const int h_start =
4350             (hpad < params.filter_height)
4351                 ? 0
4352                 : (hpad - params.filter_height) / stride_height + 1;
4353         const int h_end = std::min(hpad / stride_height + 1, output_height);
4354         const int w_start =
4355             (wpad < params.filter_width)
4356                 ? 0
4357                 : (wpad - params.filter_width) / stride_width + 1;
4358         const int w_end = std::min(wpad / stride_width + 1, output_width);
4359         // pre-compute square
4360         const int in_offset = w + input_width * (h + input_height * b);
4361         in_square =
4362             in_mat.col(in_offset).array() * in_mat.col(in_offset).array();
4363         // compute elementwise sum of squares
4364         for (int ph = h_start; ph < h_end; ++ph) {
4365           for (int pw = w_start; pw < w_end; ++pw) {
4366             const int out_offset = pw + output_width * (ph + output_height * b);
4367             out_mat.col(out_offset) += in_square;
4368             out_count(out_offset)++;
4369           }
4370         }
4371       }
4372     }
4373   }
4374 
4375   out_count = out_count.array().inverse();
4376   out_mat =
4377       (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
4378 
4379   const int flat_size = output_shape.FlatSize();
4380   for (int i = 0; i < flat_size; ++i) {
4381     output_data[i] = ActivationFunctionWithMinMax(output_data[i],
4382                                                   params.float_activation_min,
4383                                                   params.float_activation_max);
4384   }
4385 }
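
// A scalar sketch of the per-cell math implemented above with Eigen
// (hypothetical helper for illustration only; it is not used by the kernels
// in this file): the output of L2 pooling is the square root of the mean of
// the squared inputs inside the pooling window.
inline float L2PoolCellSketch(const float* window_values, int count) {
  float sum_of_squares = 0.f;
  for (int i = 0; i < count; ++i) {
    sum_of_squares += window_values[i] * window_values[i];
  }
  return std::sqrt(sum_of_squares / count);
}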
4386 
4387 inline void LocalResponseNormalization(
4388     const tflite::LocalResponseNormalizationParams& op_params,
4389     const RuntimeShape& input_shape, const float* input_data,
4390     const RuntimeShape& output_shape, float* output_data) {
4391   gemmlowp::ScopedProfilingLabel label("LocalResponseNormalization");
4392   MatchingFlatSize(input_shape, output_shape);
4393 
4394   const auto data_in = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
4395   auto data_out = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
4396 
4397   // Carry out local response normalization, vector by vector.
4398   // Since the data are stored column-major, row-wise operations would
4399   // probably not be memory efficient anyway, so we do an explicit for loop
4400   // over the columns.
4401   const int double_range = op_params.range * 2;
4402   Eigen::VectorXf padded_square(data_in.rows() + double_range);
4403   padded_square.setZero();
4404   for (int r = 0; r < data_in.cols(); ++r) {
4405     // Do local response normalization for data_in(:, r)
4406     // First, compute the squares and store them in a buffer for repeated use.
4407     padded_square.block(op_params.range, 0, data_in.rows(), 1) =
4408         data_in.col(r).cwiseProduct(data_in.col(r)) * op_params.alpha;
4409     // Then, compute the scale and write it to data_out.
4410     float accumulated_scale = 0;
4411     for (int i = 0; i < double_range; ++i) {
4412       accumulated_scale += padded_square(i);
4413     }
4414     for (int i = 0; i < data_in.rows(); ++i) {
4415       accumulated_scale += padded_square(i + double_range);
4416       data_out(i, r) = op_params.bias + accumulated_scale;
4417       accumulated_scale -= padded_square(i);
4418     }
4419   }
4420 
4421   // In a few cases, the pow computation could benefit from speedups.
4422   if (op_params.beta == 1) {
4423     data_out.array() = data_in.array() * data_out.array().inverse();
4424   } else if (op_params.beta == 0.5) {
4425     data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
4426   } else {
4427     data_out.array() = data_in.array() * data_out.array().pow(-op_params.beta);
4428   }
4429 }
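
// The loop above maintains accumulated_scale as a sliding sum over a window
// of 2 * range + 1 alpha-scaled squares centered on element i; the
// zero-padded buffer makes out-of-range neighbours contribute nothing.
// Together with the final pow/inverse step, each output is
//   out[i] = in[i] * (bias + alpha * sum_{|j - i| <= range} in[j]^2)^(-beta)
// which is the standard local response normalization formula.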
4430 
4431 inline void Softmax(const SoftmaxParams& params,
4432                     const RuntimeShape& input_shape, const float* input_data,
4433                     const RuntimeShape& output_shape, float* output_data) {
4434   gemmlowp::ScopedProfilingLabel label("Softmax");
4435   MatchingFlatSize(input_shape, output_shape);
4436 
4437   const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
4438   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
4439   // Compute the exponential first, removing the max coefficient for numerical
4440   // stability.
4441   out_mat =
4442       (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
4443   // We are separating out the exp function so that exp can be vectorized.
4444   out_mat = out_mat.array().exp();
4445   // Normalize to get the activations.
4446   Eigen::Array<float, 1, Eigen::Dynamic> scale =
4447       out_mat.array().colwise().sum().inverse();
4448   out_mat.array().rowwise() *= scale;
4449 }
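
// A scalar sketch of the row-wise computation performed above with Eigen
// (hypothetical helper for illustration only; it is not used by the kernels
// in this file and assumes depth > 0): subtract the row maximum for
// numerical stability, exponentiate, then normalize by the sum.
inline void SoftmaxRowFloatSketch(const float* input, int depth, float beta,
                                  float* output) {
  // Find the largest entry so that the exponentials cannot overflow.
  float max_value = input[0];
  for (int i = 1; i < depth; ++i) {
    max_value = std::max(max_value, input[i]);
  }
  // Exponentiate the scaled, shifted inputs and accumulate their sum.
  float sum = 0.f;
  for (int i = 0; i < depth; ++i) {
    output[i] = std::exp(beta * (input[i] - max_value));
    sum += output[i];
  }
  // Normalize so that the outputs sum to 1.
  const float inv_sum = 1.f / sum;
  for (int i = 0; i < depth; ++i) {
    output[i] *= inv_sum;
  }
}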
4450 
4451 inline void Softmax(const SoftmaxParams& params,
4452                     const RuntimeShape& input_shape, const uint8* input_data,
4453                     const RuntimeShape& output_shape, uint8* output_data) {
4454   const int32 input_beta_multiplier = params.input_multiplier;
4455   const int32 input_beta_left_shift = params.input_left_shift;
4456   const int diff_min = params.diff_min;
4457   // The representation chosen for the input to the exp() function is Q5.26.
4458   // We need to leave extra space since values that we skip might be as large as
4459   // -32 before multiplying by input_beta_multiplier, and therefore as large as
4460   // -16 afterwards.  Note that exp(-8) is definitely not insignificant to
4461   // accumulation, but exp(-16) definitely is.
4462   static const int kScaledDiffIntegerBits = 5;
4463   static const int kAccumulationIntegerBits = 12;
4464   using FixedPointScaledDiff =
4465       gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
4466   using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
4467   using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
4468 
4469   gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
4470   const int trailing_dim = input_shape.DimensionsCount() - 1;
4471   const int outer_size =
4472       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4473   const int depth =
4474       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4475 
4476   for (int b = 0; b < outer_size; ++b) {
4477     const uint8* input_data_ptr = input_data + b * depth;
4478     uint8* output_data_ptr = output_data + b * depth;
4479 
4480     // Determine the largest entry in the current row
4481     uint8 max_in_row = 0;
4482     {
4483       int c = 0;
4484 #ifdef USE_NEON
4485       uint8x16_t max16_0 = vdupq_n_u8(0);
4486       uint8x16_t max16_1 = vdupq_n_u8(0);
4487       for (; c <= depth - 32; c += 32) {
4488         max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0));
4489         max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16));
4490       }
4491       uint8x16_t max16 = vmaxq_u8(max16_0, max16_1);
4492       if (c <= depth - 16) {
4493         max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c));
4494         c += 16;
4495       }
4496       uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16));
4497       if (c <= depth - 8) {
4498         max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c));
4499         c += 8;
4500       }
4501       uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4));
4502       uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2));
4503       uint8x8_t max1 = vpmax_u8(max2, max2);
4504       max_in_row = vget_lane_u8(max1, 0);
4505 #endif
4506       for (; c < depth; ++c) {
4507         max_in_row = std::max(max_in_row, input_data_ptr[c]);
4508       }
4509     }
4510 
4511 #ifdef USE_NEON
4512     using FixedPointAccumInt32x4 =
4513         gemmlowp::FixedPoint<int32x4_t, kAccumulationIntegerBits>;
4514     using FixedPointScaledDiffInt32x4 =
4515         gemmlowp::FixedPoint<int32x4_t, kScaledDiffIntegerBits>;
4516     using FixedPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
4517     FixedPoint0Int32x4 input_beta_multiplier_f0 =
4518         FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier);
4519     int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row);
4520 #endif
4521 
4522     // Compute the sum of exponentials of the differences of entries in the
4523     // current row from the largest entry in the current row.
4524     FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
4525     {
4526       int c = 0;
4527 #ifdef USE_NEON
4528       int32x4_t diff_min_s32 = vdupq_n_s32(diff_min);
4529       FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero();
4530       FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero();
4531       FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero();
4532       for (; c <= depth - 8; c += 8) {
4533         uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
4534         int16x8_t input_diff_s16 =
4535             vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
4536         int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
4537         int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
4538         int32x4_t mask_0 =
4539             gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32);
4540         int32x4_t mask_1 =
4541             gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32);
4542         FixedPointScaledDiffInt32x4 scaled_diff_0 =
4543             input_beta_multiplier_f0 *
4544             FixedPointScaledDiffInt32x4::FromRaw(
4545                 gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
4546         FixedPointScaledDiffInt32x4 scaled_diff_1 =
4547             input_beta_multiplier_f0 *
4548             FixedPointScaledDiffInt32x4::FromRaw(
4549                 gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
4550         FixedPointAccumInt32x4 exps_0 =
4551             gemmlowp::Rescale<kAccumulationIntegerBits>(
4552                 exp_on_negative_values(scaled_diff_0));
4553         FixedPointAccumInt32x4 exps_1 =
4554             gemmlowp::Rescale<kAccumulationIntegerBits>(
4555                 exp_on_negative_values(scaled_diff_1));
4556         FixedPointAccumInt32x4 masked_exps_0 =
4557             SelectUsingMask(mask_0, exps_0, zeros);
4558         FixedPointAccumInt32x4 masked_exps_1 =
4559             SelectUsingMask(mask_1, exps_1, zeros);
4560         sum_of_exps_0 = sum_of_exps_0 + masked_exps_0;
4561         sum_of_exps_1 = sum_of_exps_1 + masked_exps_1;
4562       }
4563       int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw();
4564       int32x2_t sum_of_exps_reduced_2 =
4565           vadd_s32(vget_low_s32(sum_of_exps_reduced_4),
4566                    vget_high_s32(sum_of_exps_reduced_4));
4567       int32x2_t sum_of_exps_reduced_1 =
4568           vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2);
4569       sum_of_exps =
4570           FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0));
4571 #endif
4572       for (; c < depth; ++c) {
4573         int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
4574         if (input_diff >= diff_min) {
4575           const int32 input_diff_rescaled =
4576               MultiplyByQuantizedMultiplierGreaterThanOne(
4577                   input_diff, input_beta_multiplier, input_beta_left_shift);
4578           const FixedPointScaledDiff scaled_diff_f8 =
4579               FixedPointScaledDiff::FromRaw(input_diff_rescaled);
4580           sum_of_exps =
4581               sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
4582                                 exp_on_negative_values(scaled_diff_f8));
4583         }
4584       }
4585     }
4586 
4587     // Compute the fixed-point multiplier and shift that we need to apply to
4588     // perform a division by the above-computed sum-of-exponentials.
4589     int32 fixed_sum_of_exps = sum_of_exps.raw();
4590     int headroom_plus_one =
4591         CountLeadingZeros(static_cast<uint32>(fixed_sum_of_exps));
4592     // This is the number of bits to the left of the binary point above 1.0.
4593     // Consider fixed_sum_of_exps=1.25.  In that case shifted_scale=0.8 and
4594     // no later adjustment will be needed.
4595     int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
4596     int32 shifted_sum_minus_one = static_cast<int32>(
4597         (static_cast<uint32>(fixed_sum_of_exps) << headroom_plus_one) -
4598         (static_cast<uint32>(1) << 31));
4599     FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
4600         FixedPoint0::FromRaw(shifted_sum_minus_one));
4601 
4602     // Compute the quotients of exponentials of differences of entries in the
4603     // current row from the largest entry, over the previously-computed sum of
4604     // exponentials.
4605     {
4606       int c = 0;
4607 #ifdef USE_NEON
4608       int16x8_t diff_min_s16 = vdupq_n_s16(diff_min);
4609       for (; c <= depth - 8; c += 8) {
4610         uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
4611         int16x8_t input_diff_s16 =
4612             vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
4613         int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
4614         int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
4615         uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16));
4616         FixedPointScaledDiffInt32x4 scaled_diff_0 =
4617             input_beta_multiplier_f0 *
4618             FixedPointScaledDiffInt32x4::FromRaw(
4619                 gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
4620         FixedPointScaledDiffInt32x4 scaled_diff_1 =
4621             input_beta_multiplier_f0 *
4622             FixedPointScaledDiffInt32x4::FromRaw(
4623                 gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
4624         FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0);
4625         FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1);
4626         int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT(
4627             vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()),
4628             num_bits_over_unit + 31 - 8);
4629         int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT(
4630             vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()),
4631             num_bits_over_unit + 31 - 8);
4632         int16x8_t output_s16 =
4633             vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1));
4634         uint8x8_t output_u8 = vqmovun_s16(output_s16);
4635         uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0));
4636         vst1_u8(output_data_ptr + c, masked_output);
4637       }
4638 #endif
4639       for (; c < depth; ++c) {
4640         int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
4641         if (input_diff >= diff_min) {
4642           const int32 input_diff_rescaled =
4643               MultiplyByQuantizedMultiplierGreaterThanOne(
4644                   input_diff, input_beta_multiplier, input_beta_left_shift);
4645           const FixedPointScaledDiff scaled_diff_f8 =
4646               FixedPointScaledDiff::FromRaw(input_diff_rescaled);
4647 
4648           FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
4649           int32 unsat_output = gemmlowp::RoundingDivideByPOT(
4650               (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
4651 
4652           output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0);
4653 
4654         } else {
4655           output_data_ptr[c] = 0;
4656         }
4657       }
4658     }
4659   }
4660 }
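
// Worked example for the scale computation above (the same case the inline
// comment mentions): if the accumulated sum of exponentials is 1.25, its
// Q12.19 raw value is 1.25 * 2^19 = 655360, which has 12 leading zero bits,
// so headroom_plus_one = 12 and num_bits_over_unit = 0. Shifting left by 12
// re-expresses the sum with the binary point at bit 31; subtracting 1.0
// (1 << 31) leaves 0.25, and one_over_one_plus_x_for_x_in_0_1 then yields
// shifted_scale = 1 / 1.25 = 0.8. Each exponential (a Q0.31 value) is
// multiplied by shifted_scale and rounding-divided by
// 2^(num_bits_over_unit + 31 - 8), which maps the quotient into the
// [0, 255] uint8 output range.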
4661 
4662 // TODO(myenik): This is the same as the reference implementation, not actually
4663 // optimized yet.
4664 inline void LogSoftmax(const SoftmaxParams& params,
4665                        const RuntimeShape& input_shape, const float* input_data,
4666                        const RuntimeShape& output_shape, float* output_data) {
4667   gemmlowp::ScopedProfilingLabel label("LogSoftmax");
4668   const int trailing_dim = input_shape.DimensionsCount() - 1;
4669   const int outer_size =
4670       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4671   const int depth =
4672       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4673 
4674   for (int i = 0; i < outer_size; ++i) {
4675     const float* block_input_data = input_data + i * depth;
4676     float* block_output_data = output_data + i * depth;
4677     // Find max element value which we'll use to ensure numerical stability
4678     // taking advantage of the following equality:
4679     // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
4680     float max = std::numeric_limits<float>::lowest();
4681     for (int c = 0; c < depth; ++c) {
4682       max = std::max(max, block_input_data[c]);
4683     }
4684 
4685     // Compute sum.
4686     float sum = 0.f;
4687     for (int c = 0; c < depth; ++c) {
4688       sum += std::exp(block_input_data[c] - max);
4689     }
4690 
4691     // Compute result.
4692     const float log_sum = std::log(sum);
4693     for (int c = 0; c < depth; ++c) {
4694       block_output_data[c] = block_input_data[c] - max - log_sum;
4695     }
4696   }
4697 }
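
// The implementation relies on the identity
//   log_softmax(x)[i] = x[i] - max_j x[j] - log(sum_j exp(x[j] - max_j x[j]))
// For example, x = {1, 2, 3}: max = 3, sum = e^-2 + e^-1 + e^0 = 1.5032,
// log(sum) = 0.408, so the outputs are approximately {-2.41, -1.41, -0.41},
// whose exponentials sum to 1 as expected.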
4698 
4699 // Currently just a copy of the reference code.
4700 inline void LogSoftmax(const SoftmaxParams& params,
4701                        const RuntimeShape& input_shape, const uint8* input_data,
4702                        const RuntimeShape& output_shape, uint8* output_data) {
4703   gemmlowp::ScopedProfilingLabel label("LogSoftmax/Uint8");
4704   const int32 input_multiplier = params.input_multiplier;
4705   const int32 input_left_shift = params.input_left_shift;
4706   const int32 reverse_scaling_divisor = params.reverse_scaling_divisor;
4707   const int32 reverse_scaling_right_shift = params.reverse_scaling_right_shift;
4708   const int diff_min = params.diff_min;
4709   // The representation chosen for the input to the exp() function is Q5.26.
4710   // We need to leave extra space since values that we skip might be as large as
4711   // -32 before multiplying by input_beta_multiplier, and therefore as large as
4712   // -16 afterwards.  Note that exp(-8) is definitely not insignificant to
4713   // accumulation, but exp(-16) definitely is.
4714   static constexpr int kScaledDiffIntegerBits = 5;
4715   static constexpr int kAccumulationIntegerBits = 12;
4716   static constexpr int kOutputIntegerBits = 4;
4717   using FixedPointScaledDiff =
4718       gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
4719   using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
4720 
4721   const int trailing_dim = input_shape.DimensionsCount() - 1;
4722   const int outer_size =
4723       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4724   const int depth =
4725       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4726 
4727   for (int i = 0; i < outer_size; ++i) {
4728     const uint8* block_input_data = input_data + i * depth;
4729     uint8* block_output_data = output_data + i * depth;
4730     uint8 max_in_row = 0;
4731     for (int c = 0; c < depth; ++c) {
4732       max_in_row = std::max(max_in_row, block_input_data[c]);
4733     }
4734 
4735     FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
4736     for (int c = 0; c < depth; ++c) {
4737       int32 input_diff = static_cast<int32>(block_input_data[c]) - max_in_row;
4738       if (input_diff >= diff_min) {
4739         const int32 input_diff_rescaled =
4740             MultiplyByQuantizedMultiplierGreaterThanOne(
4741                 input_diff, input_multiplier, input_left_shift);
4742         const FixedPointScaledDiff scaled_diff_f8 =
4743             FixedPointScaledDiff::FromRaw(input_diff_rescaled);
4744         sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
4745                                         exp_on_negative_values(scaled_diff_f8));
4746       }
4747     }
4748 
4749     const int32 fixed_log_sum_of_exps =
4750         log_x_for_x_greater_than_or_equal_to_1<kScaledDiffIntegerBits>(
4751             sum_of_exps)
4752             .raw();
4753 
4754     // rescaled_diff_min is smallest representable in
4755     // Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the
4756     // log-sub-exps that will be subtracted in the loop.
4757     //
4758     // The thresholds diff_min, etc are negative.
4759     const int rescaled_diff_min =
4760         fixed_log_sum_of_exps + std::numeric_limits<int32>::lowest();
4761     const int adjusted_diff_min =
4762         std::max(diff_min - 1,  // Note use of > below instead of >= above.
4763                  MultiplyByQuantizedMultiplierSmallerThanOneExp(
4764                      rescaled_diff_min, reverse_scaling_divisor,
4765                      -reverse_scaling_right_shift));
4766 
4767     for (int c = 0; c < depth; ++c) {
4768       int32 input_diff = static_cast<int32>(block_input_data[c]) - max_in_row;
4769       if (input_diff > adjusted_diff_min) {
4770         const int32 input_diff_rescaled =
4771             MultiplyByQuantizedMultiplierGreaterThanOne(
4772                 input_diff, input_multiplier, input_left_shift);
4773         int32 unsat_output =
4774             gemmlowp::RoundingDivideByPOT(
4775                 (input_diff_rescaled - fixed_log_sum_of_exps),
4776                 31 - kScaledDiffIntegerBits - kOutputIntegerBits) +
4777             255;
4778 
4779         block_output_data[c] = static_cast<uint8>(
4780             std::max(std::min(unsat_output, static_cast<int32>(255)), 0));
4781       } else {
4782         // Set output to smallest value.
4783         block_output_data[c] = 0;
4784       }
4785     }
4786   }
4787 }
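
// Output mapping used above: input_diff_rescaled - fixed_log_sum_of_exps is
// the log-softmax value in Q5.26, and the rounding divide by
// 2^(31 - kScaledDiffIntegerBits - kOutputIntegerBits) = 2^22 expresses it in
// units of 1/16. The stored value is therefore
//   clamp(round(16 * log_softmax) + 255, 0, 255)
// so 255 corresponds to log-probability 0 (probability 1) and sufficiently
// negative log-probabilities saturate to 0.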
4788 
4789 inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
4790                      const RuntimeShape& output_shape, float* output_data) {
4791   gemmlowp::ScopedProfilingLabel label("Logistic");
4792   auto input_map = MapAsVector(input_data, input_shape);
4793   auto output_map = MapAsVector(output_data, output_shape);
4794   output_map.array() =
4795       input_map.array().unaryExpr(Eigen::internal::scalar_logistic_op<float>());
4796 }
4797 
4798 // Convenience version that allows, for example, generated-code calls to be
4799 // uniform between data types.
4800 inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
4801                      const float* input_data, const RuntimeShape& output_shape,
4802                      float* output_data) {
4803   // Drop params: not needed.
4804   Logistic(input_shape, input_data, output_shape, output_data);
4805 }
4806 
4807 inline void Logistic(const LogisticParams& params,
4808                      const RuntimeShape& input_shape, const uint8* input_data,
4809                      const RuntimeShape& output_shape, uint8* output_data) {
4810   gemmlowp::ScopedProfilingLabel label("Logistic/Uint8");
4811   const int32 input_zero_point = params.input_zero_point;
4812   const int32 input_range_radius = params.input_range_radius;
4813   const int32 input_multiplier = params.input_multiplier;
4814   const int input_left_shift = params.input_left_shift;
4815   const int size = MatchingFlatSize(input_shape, output_shape);
4816 
4817   int c = 0;
4818 #ifdef USE_NEON
4819   // Handle 16 values at a time
4820   for (; c <= size - 16; c += 16) {
4821     // Read input uint8 values, cast to int16 and subtract input_zero_point
4822     uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
4823     int16x8_t input_val_centered_0 =
4824         vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
4825                   vdupq_n_s16(input_zero_point));
4826     int16x8_t input_val_centered_1 =
4827         vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
4828                   vdupq_n_s16(input_zero_point));
4829 
4830     // Prepare the bit masks that we will use at the end to implement the logic
4831     // that was expressed in the scalar code with branching:
4832     //   if (input_val_centered < -input_range_radius) {
4833     //     output_val = 0;
4834     //   } else if (input_val_centered > input_range_radius) {
4835     //     output_val = 255;
4836     //   } else {
4837     //     ...
4838     uint16x8_t mask_rightclamp_0 =
4839         vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
4840     uint16x8_t mask_rightclamp_1 =
4841         vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
4842     uint16x8_t mask_leftclamp_0 =
4843         vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
4844     uint16x8_t mask_leftclamp_1 =
4845         vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
4846     uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
4847                                              vshrn_n_u16(mask_rightclamp_1, 8));
4848     uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
4849                                             vshrn_n_u16(mask_leftclamp_1, 8));
4850 
4851     // This performs what is expressed in the scalar code as
4852     // const int32 input_val_rescaled =
4853     //     MultiplyByQuantizedMultiplierGreaterThanOne(
4854     //         input_val_centered, input_multiplier, input_left_shift);
4855     int32x4_t input_val_rescaled_0 =
4856         vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
4857                   vdupq_n_s32(input_left_shift));
4858     int32x4_t input_val_rescaled_1 =
4859         vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
4860                   vdupq_n_s32(input_left_shift));
4861     int32x4_t input_val_rescaled_2 =
4862         vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
4863                   vdupq_n_s32(input_left_shift));
4864     int32x4_t input_val_rescaled_3 =
4865         vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
4866                   vdupq_n_s32(input_left_shift));
4867     input_val_rescaled_0 =
4868         vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
4869     input_val_rescaled_1 =
4870         vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
4871     input_val_rescaled_2 =
4872         vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
4873     input_val_rescaled_3 =
4874         vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
4875 
4876     // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t
4877     using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
4878     using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
4879     const FixedPoint4 input_val_f4_0 =
4880         FixedPoint4::FromRaw(input_val_rescaled_0);
4881     const FixedPoint4 input_val_f4_1 =
4882         FixedPoint4::FromRaw(input_val_rescaled_1);
4883     const FixedPoint4 input_val_f4_2 =
4884         FixedPoint4::FromRaw(input_val_rescaled_2);
4885     const FixedPoint4 input_val_f4_3 =
4886         FixedPoint4::FromRaw(input_val_rescaled_3);
4887     const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0);
4888     const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1);
4889     const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2);
4890     const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3);
4891 
4892     // Divide by 2^23 as in the scalar code
4893     using gemmlowp::RoundingDivideByPOT;
4894     int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23);
4895     int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23);
4896     int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23);
4897     int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23);
4898 
4899     // Cast output values to uint8, saturating
4900     int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
4901                                               vqmovn_s32(output_val_s32_1));
4902     int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
4903                                               vqmovn_s32(output_val_s32_3));
4904     uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
4905                                            vqmovun_s16(output_val_s16_1));
4906 
4907     // Perform the bit-masking with the bit masks computed at the beginning,
4908     // see the comment there.
4909     output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
4910     output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
4911 
4912     // Store back to memory
4913     vst1q_u8(output_data + c, output_val_u8);
4914   }
4915 #endif
4916   // Leftover loop: handle one value at a time with scalar code.
4917   for (; c < size; ++c) {
4918     const uint8 input_val_u8 = input_data[c];
4919     const int32 input_val_centered =
4920         static_cast<int32>(input_val_u8) - input_zero_point;
4921     uint8 output_val;
4922     if (input_val_centered < -input_range_radius) {
4923       output_val = 0;
4924     } else if (input_val_centered > input_range_radius) {
4925       output_val = 255;
4926     } else {
4927       const int32 input_val_rescaled =
4928           MultiplyByQuantizedMultiplierGreaterThanOne(
4929               input_val_centered, input_multiplier, input_left_shift);
4930       using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
4931       using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
4932       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
4933       const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
4934       using gemmlowp::RoundingDivideByPOT;
4935       int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
4936       if (output_val_s32 == 256) {
4937         output_val_s32 = 255;
4938       }
4939       TFLITE_DCHECK_GE(output_val_s32, 0);
4940       TFLITE_DCHECK_LE(output_val_s32, 255);
4941       output_val = static_cast<uint8>(output_val_s32);
4942     }
4943     output_data[c] = output_val;
4944   }
4945 }
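
// Scaling note for the scalar tail above: output_val_f0 is a Q0.31 value in
// [0, 1], so RoundingDivideByPOT(..., 23) produces round(logistic(x) * 256);
// a centered input of 0 maps to logistic 0.5 and output 128, while values at
// the very top of the range produce 256, which is then clamped to 255. The
// NEON path obtains the same clamping behaviour with the precomputed bit
// masks instead of branches.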
4946 
4947 inline void Logistic(const LogisticParams& params,
4948                      const RuntimeShape& input_shape, const int16* input_data,
4949                      const RuntimeShape& output_shape, int16* output_data) {
4950   gemmlowp::ScopedProfilingLabel label("Logistic/Int16");
4951   const int flat_size = MatchingFlatSize(input_shape, output_shape);
4952 
4956   int c = 0;
4957   const int16* input_data_ptr = input_data;
4958   int16* output_data_ptr = output_data;
4959 #ifdef GEMMLOWP_NEON
4960   {
4961     // F0 uses 0 integer bits, range [-1, 1].
4962     // This is the return type of math functions such as tanh, logistic,
4963     // whose range is in [-1, 1].
4964     using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
4965     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4966     using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
4967 
4968     for (; c <= flat_size - 16; c += 16) {
4969       F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
4970       F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
4971       F0 output0 = gemmlowp::logistic(input0);
4972       F0 output1 = gemmlowp::logistic(input1);
4973       vst1q_s16(output_data_ptr, output0.raw());
4974       vst1q_s16(output_data_ptr + 8, output1.raw());
4975 
4976       input_data_ptr += 16;
4977       output_data_ptr += 16;
4978     }
4979     for (; c <= flat_size - 8; c += 8) {
4980       F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
4981       F0 output = gemmlowp::logistic(input);
4982       vst1q_s16(output_data_ptr, output.raw());
4983 
4984       input_data_ptr += 8;
4985       output_data_ptr += 8;
4986     }
4987   }
4988 #endif
4989   {
4990     // F0 uses 0 integer bits, range [-1, 1].
4991     // This is the return type of math functions such as tanh, logistic,
4992     // whose range is in [-1, 1].
4993     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
4994     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4995     using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
4996 
4997     for (; c < flat_size; ++c) {
4998       F3 input = F3::FromRaw(*input_data_ptr);
4999       F0 output = gemmlowp::logistic(input);
5000       *output_data_ptr = output.raw();
5001 
5002       ++input_data_ptr;
5003       ++output_data_ptr;
5004     }
5005   }
5006 }
5007 
5008 inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
5009                  const RuntimeShape& output_shape, float* output_data) {
5010   gemmlowp::ScopedProfilingLabel label("Tanh");
5011   auto input_map = MapAsVector(input_data, input_shape);
5012   auto output_map = MapAsVector(output_data, output_shape);
5013   output_map.array() = input_map.array().tanh();
5014 }
5015 
5016 // Convenience version that allows, for example, generated-code calls to be
5017 // uniform between data types.
5018 inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
5019                  const float* input_data, const RuntimeShape& output_shape,
5020                  float* output_data) {
5021   // Drop params: not needed.
5022   Tanh(input_shape, input_data, output_shape, output_data);
5023 }
5024 
5025 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
5026                  const uint8* input_data, const RuntimeShape& output_shape,
5027                  uint8* output_data) {
5028   // Note that this is almost the exact same code as in Logistic().
5029   gemmlowp::ScopedProfilingLabel label("Tanh");
5030   const int32 input_zero_point = params.input_zero_point;
5031   const int32 input_range_radius = params.input_range_radius;
5032   const int32 input_multiplier = params.input_multiplier;
5033   const int input_left_shift = params.input_left_shift;
5034   const int size = MatchingFlatSize(input_shape, output_shape);
5035 
5036   int c = 0;
5037   int32_t output_zero_point = 128;
5038 #ifdef USE_NEON
5039   // Handle 16 values at a time
5040   for (; c <= size - 16; c += 16) {
5041     // Read input uint8 values, cast to int16 and subtract input_zero_point
5042     uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
5043     int16x8_t input_val_centered_0 =
5044         vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
5045                   vdupq_n_s16(input_zero_point));
5046     int16x8_t input_val_centered_1 =
5047         vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
5048                   vdupq_n_s16(input_zero_point));
5049 
5050     // Prepare the bit masks that we will use at the end to implement the logic
5051     // that was expressed in the scalar code with branching:
5052     //   if (input_val_centered < -input_range_radius) {
5053     //     output_val = 0;
5054     //   } else if (input_val_centered > input_range_radius) {
5055     //     output_val = 255;
5056     //   } else {
5057     //     ...
5058     uint16x8_t mask_rightclamp_0 =
5059         vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
5060     uint16x8_t mask_rightclamp_1 =
5061         vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
5062     uint16x8_t mask_leftclamp_0 =
5063         vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
5064     uint16x8_t mask_leftclamp_1 =
5065         vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
5066     uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
5067                                              vshrn_n_u16(mask_rightclamp_1, 8));
5068     uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
5069                                             vshrn_n_u16(mask_leftclamp_1, 8));
5070 
5071     // This performs what is expressed in the scalar code as
5072     // const int32 input_val_rescaled =
5073     //     MultiplyByQuantizedMultiplierGreaterThanOne(
5074     //         input_val_centered, input_multiplier, input_left_shift);
5075     int32x4_t input_val_rescaled_0 =
5076         vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
5077                   vdupq_n_s32(input_left_shift));
5078     int32x4_t input_val_rescaled_1 =
5079         vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
5080                   vdupq_n_s32(input_left_shift));
5081     int32x4_t input_val_rescaled_2 =
5082         vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
5083                   vdupq_n_s32(input_left_shift));
5084     int32x4_t input_val_rescaled_3 =
5085         vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
5086                   vdupq_n_s32(input_left_shift));
5087     input_val_rescaled_0 =
5088         vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
5089     input_val_rescaled_1 =
5090         vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
5091     input_val_rescaled_2 =
5092         vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
5093     input_val_rescaled_3 =
5094         vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
5095 
5096     // Invoke gemmlowp::tanh on FixedPoint wrapping int32x4_t
5097     using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
5098     using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
5099     const FixedPoint4 input_val_f4_0 =
5100         FixedPoint4::FromRaw(input_val_rescaled_0);
5101     const FixedPoint4 input_val_f4_1 =
5102         FixedPoint4::FromRaw(input_val_rescaled_1);
5103     const FixedPoint4 input_val_f4_2 =
5104         FixedPoint4::FromRaw(input_val_rescaled_2);
5105     const FixedPoint4 input_val_f4_3 =
5106         FixedPoint4::FromRaw(input_val_rescaled_3);
5107     const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0);
5108     const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1);
5109     const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2);
5110     const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3);
5111 
5112     // Divide by 2^24 as in the scalar code
5113     using gemmlowp::RoundingDivideByPOT;
5114     int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 24);
5115     int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 24);
5116     int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 24);
5117     int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 24);
5118 
5119     // Add the output zero point
5120     int32x4_t output_zero_point_s32 = vdupq_n_s32(output_zero_point);
5121     output_val_s32_0 = vaddq_s32(output_val_s32_0, output_zero_point_s32);
5122     output_val_s32_1 = vaddq_s32(output_val_s32_1, output_zero_point_s32);
5123     output_val_s32_2 = vaddq_s32(output_val_s32_2, output_zero_point_s32);
5124     output_val_s32_3 = vaddq_s32(output_val_s32_3, output_zero_point_s32);
5125 
5126     // Cast output values to uint8, saturating
5127     int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
5128                                               vqmovn_s32(output_val_s32_1));
5129     int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
5130                                               vqmovn_s32(output_val_s32_3));
5131     uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
5132                                            vqmovun_s16(output_val_s16_1));
5133 
5134     // Perform the bit-masking with the bit masks computed at the beginning,
5135     // see the comment there.
5136     output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
5137     output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
5138 
5139     // Store back to memory
5140     vst1q_u8(output_data + c, output_val_u8);
5141   }
5142 #endif
5143   // Leftover loop: handle one value at a time with scalar code.
5144   for (; c < size; ++c) {
5145     const uint8 input_val_u8 = input_data[c];
5146     const int32 input_val_centered =
5147         static_cast<int32>(input_val_u8) - input_zero_point;
5148     uint8 output_val;
5149     if (input_val_centered < -input_range_radius) {
5150       output_val = 0;
5151     } else if (input_val_centered > input_range_radius) {
5152       output_val = 255;
5153     } else {
5154       const int32 input_val_rescaled =
5155           MultiplyByQuantizedMultiplierGreaterThanOne(
5156               input_val_centered, input_multiplier, input_left_shift);
5157       using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
5158       using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
5159       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
5160       const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
5161       using gemmlowp::RoundingDivideByPOT;
5162       int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24);
5163       output_val_s32 += output_zero_point;
5164       if (output_val_s32 == 256) {
5165         output_val_s32 = 255;
5166       }
5167       TFLITE_DCHECK_GE(output_val_s32, 0);
5168       TFLITE_DCHECK_LE(output_val_s32, 255);
5169       output_val = static_cast<uint8>(output_val_s32);
5170     }
5171     output_data[c] = output_val;
5172   }
5173 }
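
// Scaling note for the scalar tail above: the tanh result is a Q0.31 value
// in [-1, 1]; RoundingDivideByPOT(..., 24) gives round(tanh(x) * 128), and
// adding the fixed output zero point of 128 maps [-1, 1] onto [0, 256], with
// 256 clamped to 255. A centered input of 0 therefore produces the output
// value 128.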
5174 
5175 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
5176                  const int16* input_data, const RuntimeShape& output_shape,
5177                  int16* output_data) {
5178   gemmlowp::ScopedProfilingLabel label("Tanh/Int16");
5179   const int input_left_shift = params.input_left_shift;
5180   // Support for shifts is limited until we have a parameterized version of
5181   // SaturatingRoundingMultiplyByPOT().
5182   TFLITE_DCHECK_GE(input_left_shift, 0);
5183   TFLITE_DCHECK_LE(input_left_shift, 1);
5184 
5185   const int flat_size = MatchingFlatSize(input_shape, output_shape);
5186 
5187   int c = 0;
5188   const int16* input_data_ptr = input_data;
5189   int16* output_data_ptr = output_data;
5190 #ifdef GEMMLOWP_NEON
5191   {
5192     // F0 uses 0 integer bits, range [-1, 1].
5193     // This is the return type of math functions such as tanh, logistic,
5194     // whose range is in [-1, 1].
5195     using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
5196     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
5197     using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
5198 
5199     if (input_left_shift == 0) {
5200       for (; c <= flat_size - 16; c += 16) {
5201         F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
5202         F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
5203         F0 output0 = gemmlowp::tanh(input0);
5204         F0 output1 = gemmlowp::tanh(input1);
5205         vst1q_s16(output_data_ptr, output0.raw());
5206         vst1q_s16(output_data_ptr + 8, output1.raw());
5207 
5208         input_data_ptr += 16;
5209         output_data_ptr += 16;
5210       }
5211       for (; c <= flat_size - 8; c += 8) {
5212         F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
5213         F0 output = gemmlowp::tanh(input);
5214         vst1q_s16(output_data_ptr, output.raw());
5215 
5216         input_data_ptr += 8;
5217         output_data_ptr += 8;
5218       }
5219     } else {
5220       for (; c <= flat_size - 16; c += 16) {
5221         F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
5222             vld1q_s16(input_data_ptr)));
5223         F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
5224             vld1q_s16(input_data_ptr + 8)));
5225         F0 output0 = gemmlowp::tanh(input0);
5226         F0 output1 = gemmlowp::tanh(input1);
5227         vst1q_s16(output_data_ptr, output0.raw());
5228         vst1q_s16(output_data_ptr + 8, output1.raw());
5229 
5230         input_data_ptr += 16;
5231         output_data_ptr += 16;
5232       }
5233       for (; c <= flat_size - 8; c += 8) {
5234         F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
5235             vld1q_s16(input_data_ptr)));
5236         F0 output = gemmlowp::tanh(input);
5237         vst1q_s16(output_data_ptr, output.raw());
5238 
5239         input_data_ptr += 8;
5240         output_data_ptr += 8;
5241       }
5242     }
5243   }
5244 #endif
5245   {
5246     // F0 uses 0 integer bits, range [-1, 1].
5247     // This is the return type of math functions such as tanh, logistic,
5248     // whose range is in [-1, 1].
5249     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
5250     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
5251     using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
5252 
5253     if (input_left_shift == 0) {
5254       for (; c < flat_size; ++c) {
5255         F3 input = F3::FromRaw(*input_data_ptr);
5256         F0 output = gemmlowp::tanh(input);
5257         *output_data_ptr = output.raw();
5258 
5259         ++input_data_ptr;
5260         ++output_data_ptr;
5261       }
5262     } else {
5263       for (; c < flat_size; ++c) {
5264         F3 input = F3::FromRaw(
5265             gemmlowp::SaturatingRoundingMultiplyByPOT<1>(*input_data_ptr));
5266         F0 output = gemmlowp::tanh(input);
5267         *output_data_ptr = output.raw();
5268 
5269         ++input_data_ptr;
5270         ++output_data_ptr;
5271       }
5272     }
5273   }
5274 }
5275 
5276 template <typename SrcT, typename DstT>
5277 inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
5278                  const RuntimeShape& output_shape, DstT* output_data) {
5279   gemmlowp::ScopedProfilingLabel label("Cast");
5280   auto input_map = MapAsVector(input_data, input_shape);
5281   auto output_map = MapAsVector(output_data, output_shape);
5282   output_map.array() = input_map.array().template cast<DstT>();
5283 }
5284 
5285 inline void Floor(const RuntimeShape& input_shape, const float* input_data,
5286                   const RuntimeShape& output_shape, float* output_data) {
5287   gemmlowp::ScopedProfilingLabel label("Floor");
5288   auto input_map = MapAsVector(input_data, input_shape);
5289   auto output_map = MapAsVector(output_data, output_shape);
5290   output_map.array() = Eigen::floor(input_map.array());
5291 }
5292 
5293 inline void Ceil(const RuntimeShape& input_shape, const float* input_data,
5294                  const RuntimeShape& output_shape, float* output_data) {
5295   gemmlowp::ScopedProfilingLabel label("Ceil");
5296   auto input_map = MapAsVector(input_data, input_shape);
5297   auto output_map = MapAsVector(output_data, output_shape);
5298   output_map.array() = Eigen::ceil(input_map.array());
5299 }
5300 
5301 #ifdef USE_NEON
5302 inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
5303                                  float scale, float* output_ptr) {
5304   int ic = 0;
5305   // Handle 32 input channels at a time.
5306   for (; ic <= depth - 32; ic += 32) {
5307     float32x4x2_t input[4];
5308     for (int i = 0; i < 4; i++) {
5309       input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
5310       input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
5311     }
5312     float32x4x2_t acc[4];
5313     for (int i = 0; i < 4; i++) {
5314       acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
5315       acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
5316     }
5317     for (int i = 0; i < 4; i++) {
5318       acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
5319       acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
5320     }
5321     for (int i = 0; i < 4; i++) {
5322       vst1q_f32(output_ptr, acc[i].val[0]);
5323       vst1q_f32(output_ptr + 4, acc[i].val[1]);
5324       output_ptr += 8;
5325     }
5326     input_ptr += 32;
5327   }
5328   // Handle 16 input channels at a time.
5329   for (; ic <= depth - 16; ic += 16) {
5330     float32x4x2_t input[2];
5331     for (int i = 0; i < 2; i++) {
5332       input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
5333       input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
5334     }
5335     float32x4x2_t acc[2];
5336     for (int i = 0; i < 2; i++) {
5337       acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
5338       acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
5339     }
5340     for (int i = 0; i < 2; i++) {
5341       acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
5342       acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
5343     }
5344     for (int i = 0; i < 2; i++) {
5345       vst1q_f32(output_ptr, acc[i].val[0]);
5346       vst1q_f32(output_ptr + 4, acc[i].val[1]);
5347       output_ptr += 8;
5348     }
5349     input_ptr += 16;
5350   }
5351   // Handle 8 input channels at a time.
5352   for (; ic <= depth - 8; ic += 8) {
5353     float32x4x2_t input;
5354     input.val[0] = vld1q_f32(input_ptr);
5355     input.val[1] = vld1q_f32(input_ptr + 4);
5356 
5357     float32x4x2_t acc;
5358     acc.val[0] = vld1q_f32(output_ptr);
5359     acc.val[1] = vld1q_f32(output_ptr + 4);
5360     acc.val[0] = vmlaq_n_f32(acc.val[0], input.val[0], scale);
5361     acc.val[1] = vmlaq_n_f32(acc.val[1], input.val[1], scale);
5362 
5363     vst1q_f32(output_ptr, acc.val[0]);
5364     vst1q_f32(output_ptr + 4, acc.val[1]);
5365 
5366     input_ptr += 8;
5367     output_ptr += 8;
5368   }
5369   // Handle 4 input channels at a time.
5370   for (; ic <= depth - 4; ic += 4) {
5371     float32x4_t input = vld1q_f32(input_ptr);
5372     float32x4_t acc = vld1q_f32(output_ptr);
5373 
5374     acc = vmlaq_n_f32(acc, input, scale);
5375     vst1q_f32(output_ptr, acc);
5376 
5377     input_ptr += 4;
5378     output_ptr += 4;
5379   }
5380   // Handle 1 input channel at a time.
5381   for (; ic < depth; ic++) {
5382     *output_ptr += *input_ptr * scale;
5383     output_ptr++;
5384     input_ptr++;
5385   }
5386 }
5387 #else
5388 inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
5389                                  float scale, float* output_ptr) {
5390   for (int32 i = 0; i < depth; i++) {
5391     *output_ptr += *input_ptr * scale;
5392     output_ptr++;
5393     input_ptr++;
5394   }
5395 }
5396 #endif
5397 
5398 inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
5399                                     int32 x, int32 y, int32 depth, int32 batch,
5400                                     const RuntimeShape& input_shape,
5401                                     const float* input_data,
5402                                     const RuntimeShape& output_shape,
5403                                     float* output_data) {
5404   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5405   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
5406   const int32 input_width = input_shape.Dims(2);
5407   const int32 output_width = output_shape.Dims(2);
5408 
5409   const int32 input_x_offset = (x1 - x0) * depth;
5410   const int32 input_y_offset = (y1 - y0) * depth * input_width;
5411   const int32 output_x_offset = depth;
5412   const int32 output_y_offset = depth * output_width;
5413 
5414 #ifdef USE_NEON
5415   TFLITE_DCHECK(x1 >= x0);
5416   TFLITE_DCHECK(y1 >= y0);
5417 
5418   int ic = 0;
5419   // Handle 8 input channels at a time.
5420   for (; ic <= depth - 8; ic += 8) {
5421     const float* input_ptr = nullptr;
5422 
5423     float32x4x2_t x0y0;
5424     input_ptr = &input_data[Offset(input_shape, batch, y0, x0, ic)];
5425     x0y0.val[0] = vld1q_f32(input_ptr);
5426     x0y0.val[1] = vld1q_f32(input_ptr + 4);
5427 
5428     float32x4x2_t x1y0;
5429     input_ptr += input_x_offset;
5430     x1y0.val[0] = vld1q_f32(input_ptr);
5431     x1y0.val[1] = vld1q_f32(input_ptr + 4);
5432 
5433     float32x4x2_t x0y1;
5434     input_ptr += -input_x_offset + input_y_offset;
5435     x0y1.val[0] = vld1q_f32(input_ptr);
5436     x0y1.val[1] = vld1q_f32(input_ptr + 4);
5437 
5438     float32x4x2_t x1y1;
5439     input_ptr += input_x_offset;
5440     x1y1.val[0] = vld1q_f32(input_ptr);
5441     x1y1.val[1] = vld1q_f32(input_ptr + 4);
5442 
5443     // Top left corner.
5444     float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
5445     vst1q_f32(output_ptr, x0y0.val[0]);
5446     vst1q_f32(output_ptr + 4, x0y0.val[1]);
5447 
5448     // Top right corner.
5449     output_ptr += output_x_offset;
5450     float32x4x2_t tr;
5451     tr.val[0] = vaddq_f32(x0y0.val[0], x1y0.val[0]);
5452     tr.val[1] = vaddq_f32(x0y0.val[1], x1y0.val[1]);
5453     tr.val[0] = vmulq_n_f32(tr.val[0], 0.5f);
5454     tr.val[1] = vmulq_n_f32(tr.val[1], 0.5f);
5455 
5456     vst1q_f32(output_ptr, tr.val[0]);
5457     vst1q_f32(output_ptr + 4, tr.val[1]);
5458 
5459     // Bottom left corner.
5460     output_ptr += -output_x_offset + output_y_offset;
5461     float32x4x2_t bl;
5462     bl.val[0] = vaddq_f32(x0y0.val[0], x0y1.val[0]);
5463     bl.val[1] = vaddq_f32(x0y0.val[1], x0y1.val[1]);
5464     bl.val[0] = vmulq_n_f32(bl.val[0], 0.5f);
5465     bl.val[1] = vmulq_n_f32(bl.val[1], 0.5f);
5466     vst1q_f32(output_ptr, bl.val[0]);
5467     vst1q_f32(output_ptr + 4, bl.val[1]);
5468 
5469     // Bottom right corner.
5470     output_ptr += output_x_offset;
5471     float32x4x2_t br;
5472     br.val[0] = vaddq_f32(x1y0.val[0], x1y1.val[0]);
5473     br.val[1] = vaddq_f32(x1y0.val[1], x1y1.val[1]);
5474     br.val[0] = vmlaq_n_f32(bl.val[0], br.val[0], 0.5f);
5475     br.val[1] = vmlaq_n_f32(bl.val[1], br.val[1], 0.5f);
5476     br.val[0] = vmulq_n_f32(br.val[0], 0.5f);
5477     br.val[1] = vmulq_n_f32(br.val[1], 0.5f);
5478     vst1q_f32(output_ptr, br.val[0]);
5479     vst1q_f32(output_ptr + 4, br.val[1]);
5480   }
5481   // Handle 4 input channels at a time.
5482   for (; ic <= depth - 4; ic += 4) {
5483     const float* input_ptr =
5484         &input_data[Offset(input_shape, batch, y0, x0, ic)];
5485     float32x4_t x0y0 = vld1q_f32(input_ptr);
5486     float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset);
5487     float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset);
5488     float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset);
5489 
5490     // Top left corner.
5491     float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
5492     vst1q_f32(output_ptr, x0y0);
5493 
5494     // Top right corner.
5495     output_ptr += output_x_offset;
5496     float32x4_t tr = vaddq_f32(x0y0, x1y0);
5497     tr = vmulq_n_f32(tr, 0.5f);
5498     vst1q_f32(output_ptr, tr);
5499 
5500     // Bottom left corner.
5501     output_ptr += -output_x_offset + output_y_offset;
5502     float32x4_t bl = vaddq_f32(x0y0, x0y1);
5503     bl = vmulq_n_f32(bl, 0.5f);
5504     vst1q_f32(output_ptr, bl);
5505 
5506     // Bottom right corner.
5507     output_ptr += output_x_offset;
5508     float32x4_t br = vaddq_f32(x1y0, x1y1);
5509     br = vmlaq_n_f32(bl, br, 0.5f);
5510     br = vmulq_n_f32(br, 0.5f);
5511     vst1q_f32(output_ptr, br);
5512   }
5513   // Handle one input channel at a time.
5514   for (; ic < depth; ic++) {
5515     const int32 input_offset = Offset(input_shape, batch, y0, x0, ic);
5516 
5517     float x0y0 = input_data[input_offset];
5518     float x1y0 = input_data[input_offset + input_x_offset];
5519     float x0y1 = input_data[input_offset + input_y_offset];
5520     float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
5521 
5522     // Top left corner.
5523     const int32 output_offset = Offset(output_shape, batch, y, x, ic);
5524     output_data[output_offset] = x0y0;
5525 
5526     // Top right corner.
5527     output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
5528 
5529     // Bottom left corner.
5530     float output = (x0y0 + x0y1) / 2;
5531     output_data[output_offset + output_y_offset] = output;
5532 
5533     // Bottom right corner.
5534     output_data[output_offset + output_x_offset + output_y_offset] =
5535         (output + ((x1y0 + x1y1) / 2)) / 2;
5536   }
5537 #else
5538   for (int ch = 0; ch < depth; ch++) {
5539     const int32 input_offset = Offset(input_shape, batch, y0, x0, ch);
5540 
5541     float x0y0 = input_data[input_offset];
5542     float x1y0 = input_data[input_offset + input_x_offset];
5543     float x0y1 = input_data[input_offset + input_y_offset];
5544     float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
5545 
5546     // Top left corner.
5547     const int32 output_offset = Offset(output_shape, batch, y, x, ch);
5548     output_data[output_offset] = x0y0;
5549 
5550     // Top right corner.
5551     output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
5552 
5553     // Bottom left corner.
5554     float output = (x0y0 + x0y1) / 2;
5555     output_data[output_offset + output_y_offset] = output;
5556 
5557     // Bottom right corner.
5558     output_data[output_offset + output_x_offset + output_y_offset] =
5559         (output + ((x1y0 + x1y1) / 2)) / 2;
5560   }
5561 #endif
5562 }
5563 
5564 inline void ResizeBilinear2x2(int32 batches, int32 input_height,
5565                               int32 input_width, int32 depth,
5566                               int32 output_height, int32 output_width,
5567                               const RuntimeShape& input_shape,
5568                               const float* input_data,
5569                               const RuntimeShape& output_shape,
5570                               float* output_data) {
5571   for (int b = 0; b < batches; b++) {
5572     for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) {
5573       for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) {
5574         int32 x1 = std::min(x0 + 1, input_width - 1);
5575         int32 y1 = std::min(y0 + 1, input_height - 1);
5576         ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_shape,
5577                                 input_data, output_shape, output_data);
5578       }
5579     }
5580   }
5581 }
5582 
5583 inline void ResizeBilinearGeneric(
5584     int32 batches, int32 input_height, int32 input_width, int32 depth,
5585     int32 output_height, int32 output_width, float height_scale,
5586     float width_scale, const RuntimeShape& input_shape, const float* input_data,
5587     const RuntimeShape& output_shape, float* output_data) {
5588   memset(output_data, 0,
5589          batches * output_height * output_width * depth * sizeof(float));
5590 
5591   int32 output_offset = 0;
5592   for (int b = 0; b < batches; ++b) {
5593     for (int y = 0; y < output_height; ++y) {
5594       float input_y = y * height_scale;
5595       int32 y0 = static_cast<int32>(std::floor(input_y));
5596       int32 y1 = std::min(y0 + 1, input_height - 1);
5597       for (int x = 0; x < output_width; ++x) {
5598         float input_x = x * width_scale;
5599         int32 x0 = static_cast<int32>(input_x);
5600         int32 x1 = std::min(x0 + 1, input_width - 1);
5601         float* output_ptr = &output_data[output_offset];
5602 
5603         // Run kernel on the 4 corners of the bilinear resize algorithm.
5604         int32 input_offset = Offset(input_shape, b, y0, x0, 0);
5605         float scale = (1 - (input_y - y0)) * (1 - (input_x - x0));
5606         const float* input_ptr = &input_data[input_offset];
5607         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
5608 
5609         input_offset = Offset(input_shape, b, y0, x1, 0);
5610         scale = (1 - (input_y - y0)) * (input_x - x0);
5611         input_ptr = &input_data[input_offset];
5612         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
5613 
5614         input_offset = Offset(input_shape, b, y1, x0, 0);
5615         scale = (input_y - y0) * (1 - (input_x - x0));
5616         input_ptr = &input_data[input_offset];
5617         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
5618 
5619         input_offset = Offset(input_shape, b, y1, x1, 0);
5620         scale = (input_y - y0) * (input_x - x0);
5621         input_ptr = &input_data[input_offset];
5622         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
5623 
5624         output_offset += depth;
5625       }
5626     }
5627   }
5628 }
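// Worked example for the four corner weights computed above (illustrative):
// if input_y = 1.25 (so y0 = 1) and input_x = 2.75 (so x0 = 2), the four
// scales are 0.75 * 0.25 = 0.1875, 0.75 * 0.75 = 0.5625, 0.25 * 0.25 = 0.0625
// and 0.25 * 0.75 = 0.1875, which sum to 1.0, so every output pixel is a
// convex combination of its four nearest input pixels.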
5629 
5630 template <typename T>
5631 inline void ResizeBilinearGenericSmallChannel(
5632     int32 batches, int32 input_height, int32 input_width, int32 depth,
5633     int32 output_height, int32 output_width, float height_scale,
5634     float width_scale, const RuntimeShape& input_shape, const T* input_data,
5635     const RuntimeShape& output_shape, T* output_data) {
5636   T* output_ptr = &output_data[0];
5637   for (int b = 0; b < batches; ++b) {
5638     for (int y = 0; y < output_height; ++y) {
5639       float input_y = y * height_scale;
5640       int32 y0 = static_cast<int32>(std::floor(input_y));
5641       int32 y1 = std::min(y0 + 1, input_height - 1);
5642       for (int x = 0; x < output_width; ++x) {
5643         float input_x = x * width_scale;
5644         int32 x0 = static_cast<int32>(std::floor((input_x)));
5645         int32 x1 = std::min(x0 + 1, input_width - 1);
5646 
5647         int32 input_offset[4] = {Offset(input_shape, b, y0, x0, 0),
5648                                  Offset(input_shape, b, y0, x1, 0),
5649                                  Offset(input_shape, b, y1, x0, 0),
5650                                  Offset(input_shape, b, y1, x1, 0)};
5651         float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)),
5652                           (1 - (input_y - y0)) * (input_x - x0),
5653                           (input_y - y0) * (1 - (input_x - x0)),
5654                           (input_y - y0) * (input_x - x0)};
5655 
5656         for (int d = 0; d < depth; d++) {
5657           const T* input_ptr = &input_data[d];
5658           *output_ptr++ = static_cast<T>(input_ptr[input_offset[0]] * scale[0] +
5659                                          input_ptr[input_offset[1]] * scale[1] +
5660                                          input_ptr[input_offset[2]] * scale[2] +
5661                                          input_ptr[input_offset[3]] * scale[3]);
5662         }
5663       }
5664     }
5665   }
5666 }
5667 
5668 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
5669                            const RuntimeShape& unextended_input_shape,
5670                            const float* input_data,
5671                            const RuntimeShape& output_size_shape,
5672                            const int32* output_size_data,
5673                            const RuntimeShape& unextended_output_shape,
5674                            float* output_data) {
5675   gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
5676   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
5677   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
5678   const RuntimeShape input_shape =
5679       RuntimeShape::ExtendedShape(4, unextended_input_shape);
5680   const RuntimeShape output_shape =
5681       RuntimeShape::ExtendedShape(4, unextended_output_shape);
5682 
5683   int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
5684   int32 input_height = input_shape.Dims(1);
5685   int32 input_width = input_shape.Dims(2);
5686   int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
5687 
5688   TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
5689   int32 output_height = output_size_data[0];
5690   int32 output_width = output_size_data[1];
5691 
5692   // Specialize for 2x2 upsample.
5693   if (!op_params.align_corners && output_height == 2 * input_height &&
5694       output_width == 2 * input_width) {
5695     ResizeBilinear2x2(batches, input_height, input_width, depth, output_height,
5696                       output_width, input_shape, input_data, output_shape,
5697                       output_data);
5698   } else {
5699     float height_scale = static_cast<float>(input_height) / output_height;
5700     float width_scale = static_cast<float>(input_width) / output_width;
5701     if (op_params.align_corners && output_height > 1) {
5702       height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
5703     }
5704     if (op_params.align_corners && output_width > 1) {
5705       width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
5706     }
5707 
5708     ResizeBilinearGeneric(batches, input_height, input_width, depth,
5709                           output_height, output_width, height_scale,
5710                           width_scale, input_shape, input_data, output_shape,
5711                           output_data);
5712   }
5713 }
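// For example (illustrative): with input_height = 2 and output_height = 4,
// the default height_scale is 2 / 4 = 0.5, while align_corners gives
// (2 - 1) / (4 - 1) = 1/3, so the last output row (y = 3) maps exactly onto
// the last input row.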
5714 
5715 // TODO(prabhumk): This is not a real quantized bilinear. It does not use int8
5716 // or int16 arithmetic.
5717 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
5718                            const RuntimeShape& unextended_input_shape,
5719                            const uint8* input_data,
5720                            const RuntimeShape& output_size_shape,
5721                            const int32* output_size_data,
5722                            const RuntimeShape& unextended_output_shape,
5723                            uint8* output_data) {
5724   gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
5725   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
5726   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
5727   const RuntimeShape input_shape =
5728       RuntimeShape::ExtendedShape(4, unextended_input_shape);
5729   const RuntimeShape output_shape =
5730       RuntimeShape::ExtendedShape(4, unextended_output_shape);
5731 
5732   int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
5733   int32 input_height = input_shape.Dims(1);
5734   int32 input_width = input_shape.Dims(2);
5735   int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
5736 
5737   TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
5738   int32 output_height = output_size_data[0];
5739   int32 output_width = output_size_data[1];
5740 
5741   float height_scale =
5742       (op_params.align_corners && output_height > 1)
5743           ? (static_cast<float>(input_height - 1) / (output_height - 1))
5744           : (static_cast<float>(input_height) / output_height);
5745 
5746   float width_scale =
5747       (op_params.align_corners && output_width > 1)
5748           ? (static_cast<float>(input_width - 1) / (output_width - 1))
5749           : (static_cast<float>(input_width) / output_width);
5750 
5751   ResizeBilinearGenericSmallChannel<uint8>(
5752       batches, input_height, input_width, depth, output_height, output_width,
5753       height_scale, width_scale, input_shape, input_data, output_shape,
5754       output_data);
5755 }
5756 
5757 // Helper methods for BatchToSpaceND.
5758 // `spatial_index_dim` specifies the post-crop offset index in this spatial
5759 // dimension, i.e. the spatial offset introduced by flattening the batch into
5760 // the spatial dimension, minus the crop size at the beginning.
5761 // `block_shape_dim` is the block size in the current dimension. `input_dim`
5762 // and `output_dim` are the input and output sizes of the BatchToSpaceND op in
5763 // this dimension. The returned start index is inclusive; the end is exclusive.
5764 inline void GetIndexRange(int spatial_index_dim, int block_shape_dim,
5765                           int input_dim, int output_dim, int* start_index,
5766                           int* end_index) {
5767   // (*start_index) * block_shape_dim is effectively rounded up to the next
5768   // multiple of block_shape_dim by the integer division.
5769   *start_index =
5770       std::max(0, (-spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
5771   // Similarly, (*end_index) * block_shape_dim is rounded up too (note that
5772   // end_index is exclusive).
5773   *end_index = std::min(
5774       input_dim,
5775       (output_dim - spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
5776 }
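// Worked example (illustrative): GetIndexRange(1, 2, 3, 5, &start, &end),
// i.e. spatial offset 1, block size 2, input size 3, output size 5, yields
// start = 0 and end = 2: input indices 0 and 1 land at output positions
// 0 * 2 + 1 = 1 and 1 * 2 + 1 = 3, while input index 2 would land at 5,
// which is outside [0, 5).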
5777 
5778 template <typename T>
5779 inline void BatchToSpaceND(
5780     const RuntimeShape& unextended_input1_shape, const T* input1_data,
5781     const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
5782     const RuntimeShape& unextended_input3_shape, const int32* crops_data,
5783     const RuntimeShape& unextended_output_shape, T* output_data) {
5784   gemmlowp::ScopedProfilingLabel label("BatchToSpaceND");
5785 
5786   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
5787   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
5788   const RuntimeShape input1_shape =
5789       RuntimeShape::ExtendedShape(4, unextended_input1_shape);
5790   const RuntimeShape output_shape =
5791       RuntimeShape::ExtendedShape(4, unextended_output_shape);
5792 
5793   const int output_width = output_shape.Dims(2);
5794   const int output_height = output_shape.Dims(1);
5795   const int output_batch_size = output_shape.Dims(0);
5796 
5797   const int depth = input1_shape.Dims(3);
5798   const int input_width = input1_shape.Dims(2);
5799   const int input_height = input1_shape.Dims(1);
5800   const int input_batch_size = input1_shape.Dims(0);
5801 
5802   const int block_shape_width = block_shape_data[1];
5803   const int block_shape_height = block_shape_data[0];
5804   const int crops_top = crops_data[0];
5805   const int crops_left = crops_data[2];
5806 
5807   for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
5808     const int out_batch = in_batch % output_batch_size;
5809     const int spatial_offset = in_batch / output_batch_size;
5810 
5811     int in_h_start = 0;
5812     int in_h_end = 0;
5813     // GetIndexRange ensures start and end indices are in [0, output_height).
5814     GetIndexRange(spatial_offset / block_shape_width - crops_top,
5815                   block_shape_height, input_height, output_height, &in_h_start,
5816                   &in_h_end);
5817 
5818     for (int in_h = in_h_start; in_h < in_h_end; ++in_h) {
5819       const int out_h = in_h * block_shape_height +
5820                         spatial_offset / block_shape_width - crops_top;
5821       TFLITE_DCHECK_GE(out_h, 0);
5822       TFLITE_DCHECK_LT(out_h, output_height);
5823 
5824       int in_w_start = 0;
5825       int in_w_end = 0;
5826       // GetIndexRange ensures start and end indices are in [0, output_width).
5827       GetIndexRange(spatial_offset % block_shape_width - crops_left,
5828                     block_shape_width, input_width, output_width, &in_w_start,
5829                     &in_w_end);
5830 
5831       for (int in_w = in_w_start; in_w < in_w_end; ++in_w) {
5832         const int out_w = in_w * block_shape_width +
5833                           spatial_offset % block_shape_width - crops_left;
5834         TFLITE_DCHECK_GE(out_w, 0);
5835         TFLITE_DCHECK_LT(out_w, output_width);
5836         T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
5837         const T* in =
5838             input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
5839         memcpy(out, in, depth * sizeof(T));
5840       }
5841     }
5842   }
5843 }
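// For example (illustrative): with input_batch_size = 4, output_batch_size = 1
// and block_shape = {2, 2}, input batch 3 maps to output batch 3 % 1 = 0 with
// spatial_offset = 3, i.e. a row offset of 3 / 2 = 1 and a column offset of
// 3 % 2 = 1 within each 2x2 block (before the crops are subtracted).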
5844 
5845 template <typename T>
5846 void TypedMemset(void* ptr, T value, size_t num) {
5847   // Optimization for common cases where memset() will suffice.
5848   if (value == 0 || std::is_same<T, uint8_t>::value) {
5849     memset(ptr, value, num * sizeof(T));
5850   } else {
5851     // Default implementation for cases where memset() will not preserve the
5852     // bytes, e.g., typically when sizeof(T) > sizeof(uint8_t).
5853     char* pos = static_cast<char*>(ptr);
5854     for (size_t i = 0; i < num; ++i) {
5855       memcpy(pos, &value, sizeof(T));
5856       pos = pos + sizeof(T);
5857     }
5858   }
5859 }
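// Usage sketch (illustrative): TypedMemset<float>(dst, 1.5f, n) takes the
// per-element memcpy path, because the byte pattern of 1.5f cannot be
// reproduced by a byte-wise memset(), whereas TypedMemset<float>(dst, 0.0f, n)
// and any uint8_t value fall through to the single memset() call above.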
5860 
5861 // This makes heavy use of Offset, along with conditional branches. There may be
5862 // opportunities for improvement.
5863 //
5864 // There are two versions of pad: Pad and PadV2.  In PadV2 there is a second
5865 // scalar input that provides the padding value.  Therefore pad_value_ptr can be
5866 // equivalent to a simple input1_data.  For Pad, it should point to a zero
5867 // value.
5868 //
5869 // Note that two typenames are required, so that T=P=int32 is considered a
5870 // specialization distinct from the case where only P is int32.
5871 template <typename T, typename P>
5872 inline void PadImpl(const tflite::PadParams& op_params,
5873                     const RuntimeShape& input_shape, const T* input_data,
5874                     const P* pad_value_ptr, const RuntimeShape& output_shape,
5875                     T* output_data) {
5876   gemmlowp::ScopedProfilingLabel label("Pad4DSlowImpl");
5877   const RuntimeShape ext_input_shape =
5878       RuntimeShape::ExtendedShape(4, input_shape);
5879   const RuntimeShape ext_output_shape =
5880       RuntimeShape::ExtendedShape(4, output_shape);
5881   TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
5882   TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
5883 
5884   // Pad kernels are limited to max 4 dimensions. Copy inputs so we can pad them
5885   // to 4 dims (yes, we are "padding the padding").
5886   std::vector<int> left_padding_copy(4, 0);
5887   const int left_padding_extend = 4 - op_params.left_padding_count;
5888   for (int i = 0; i < op_params.left_padding_count; ++i) {
5889     left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
5890   }
5891   std::vector<int> right_padding_copy(4, 0);
5892   const int right_padding_extend = 4 - op_params.right_padding_count;
5893   for (int i = 0; i < op_params.right_padding_count; ++i) {
5894     right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
5895   }
5896 
5897   const int output_batch = ext_output_shape.Dims(0);
5898   const int output_height = ext_output_shape.Dims(1);
5899   const int output_width = ext_output_shape.Dims(2);
5900   const int output_depth = ext_output_shape.Dims(3);
5901 
5902   const int left_b_padding = left_padding_copy[0];
5903   const int left_h_padding = left_padding_copy[1];
5904   const int left_w_padding = left_padding_copy[2];
5905   const int left_d_padding = left_padding_copy[3];
5906 
5907   const int right_b_padding = right_padding_copy[0];
5908   const int right_h_padding = right_padding_copy[1];
5909   const int right_w_padding = right_padding_copy[2];
5910   const int right_d_padding = right_padding_copy[3];
5911 
5912   const int input_depth = ext_input_shape.Dims(3);
5913   const T pad_value = *pad_value_ptr;
5914 
5915   if (left_b_padding != 0) {
5916     TypedMemset<T>(
5917         output_data, pad_value,
5918         left_b_padding * output_height * output_width * output_depth);
5919   }
5920   for (int out_b = left_b_padding; out_b < output_batch - right_b_padding;
5921        ++out_b) {
5922     if (left_h_padding != 0) {
5923       TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, 0, 0, 0),
5924                      pad_value, left_h_padding * output_width * output_depth);
5925     }
5926     for (int out_h = left_h_padding; out_h < output_height - right_h_padding;
5927          ++out_h) {
5928       if (left_w_padding != 0) {
5929         TypedMemset<T>(
5930             output_data + Offset(ext_output_shape, out_b, out_h, 0, 0),
5931             pad_value, left_w_padding * output_depth);
5932       }
5933       for (int out_w = left_w_padding; out_w < output_width - right_w_padding;
5934            ++out_w) {
5935         if (left_d_padding != 0) {
5936           TypedMemset<T>(
5937               output_data + Offset(ext_output_shape, out_b, out_h, out_w, 0),
5938               pad_value, left_d_padding);
5939         }
5940 
5941         T* out = output_data +
5942                  Offset(ext_output_shape, out_b, out_h, out_w, left_d_padding);
5943         const T* in = input_data +
5944                       Offset(ext_input_shape, out_b - left_b_padding,
5945                              out_h - left_h_padding, out_w - left_w_padding, 0);
5946         memcpy(out, in, input_depth * sizeof(T));
5947 
5948         if (right_d_padding != 0) {
5949           TypedMemset<T>(
5950               output_data + Offset(ext_output_shape, out_b, out_h, out_w,
5951                                    output_depth - right_d_padding),
5952               pad_value, right_d_padding);
5953         }
5954       }
5955       if (right_w_padding != 0) {
5956         TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, out_h,
5957                                             output_width - right_w_padding, 0),
5958                        pad_value, right_w_padding * output_depth);
5959       }
5960     }
5961     if (right_h_padding != 0) {
5962       TypedMemset<T>(
5963           output_data + Offset(ext_output_shape, out_b,
5964                                output_height - right_h_padding, 0, 0),
5965           pad_value, right_h_padding * output_width * output_depth);
5966     }
5967   }
5968   if (right_b_padding != 0) {
5969     TypedMemset<T>(
5970         output_data +
5971             Offset(ext_output_shape, output_batch - right_b_padding, 0, 0, 0),
5972         pad_value,
5973         right_b_padding * output_height * output_width * output_depth);
5974   }
5975 }
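// Call sketch (illustrative; the shapes and values are assumptions chosen only
// to exercise the PadParams fields read by PadImpl above):
//
//   tflite::PadParams op_params;
//   op_params.left_padding_count = 4;
//   op_params.right_padding_count = 4;
//   // Pad one pixel of zeros around H and W of an NHWC tensor.
//   const int left[4] = {0, 1, 1, 0};
//   const int right[4] = {0, 1, 1, 0};
//   for (int i = 0; i < 4; ++i) {
//     op_params.left_padding[i] = left[i];
//     op_params.right_padding[i] = right[i];
//   }
//   const float pad_value = 0.0f;
//   // input_shape is {1, H, W, C}; output_shape is {1, H + 2, W + 2, C}.
//   Pad(op_params, input_shape, input_data, &pad_value, output_shape,
//       output_data);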
5976 
5977 template <typename T, typename P>
5978 inline void Pad(const tflite::PadParams& op_params,
5979                 const RuntimeShape& input_shape, const T* input_data,
5980                 const P* pad_value_ptr, const RuntimeShape& output_shape,
5981                 T* output_data) {
5982   PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
5983           output_data);
5984 }
5985 
5986 // The second (pad-value) input can be int32 when, say, the first is uint8.
5987 template <typename T>
5988 inline void Pad(const tflite::PadParams& op_params,
5989                 const RuntimeShape& input_shape, const T* input_data,
5990                 const int32* pad_value_ptr, const RuntimeShape& output_shape,
5991                 T* output_data) {
5992   const T converted_pad_value = static_cast<T>(*pad_value_ptr);
5993   PadImpl(op_params, input_shape, input_data, &converted_pad_value,
5994           output_shape, output_data);
5995 }
5996 
5997 // This version avoids conflicting template matching.
5998 template <>
5999 inline void Pad(const tflite::PadParams& op_params,
6000                 const RuntimeShape& input_shape, const int32* input_data,
6001                 const int32* pad_value_ptr, const RuntimeShape& output_shape,
6002                 int32* output_data) {
6003   PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
6004           output_data);
6005 }
6006 
6007 // TODO(b/117643175): Optimize. (This is an introductory copy of standard Pad.)
6008 //
6009 // This pad requires that (a) left and right paddings are in the 4D patterns
6010 // {0, h_pad, w_pad, 0}, and (b) memset can be used: *pad_value_ptr == 0 and/or
6011 // T is uint8.
6012 //
6013 // There are two versions of pad: Pad and PadV2.  In PadV2 there is a second
6014 // scalar input that provides the padding value.  Therefore pad_value_ptr can be
6015 // equivalent to a simple input1_data.  For Pad, it should point to a zero
6016 // value.
6017 //
6018 // Note that two typenames are required, so that T=P=int32 is considered a
6019 // specialization distinct from the case where only P is int32.
6020 template <typename T, typename P>
6021 inline void PadImageStyleMemset(const tflite::PadParams& op_params,
6022                                 const RuntimeShape& input_shape,
6023                                 const T* input_data, const P* pad_value_ptr,
6024                                 const RuntimeShape& output_shape,
6025                                 T* output_data) {
6026   gemmlowp::ScopedProfilingLabel label("PadImageStyle");
6027   const RuntimeShape ext_input_shape =
6028       RuntimeShape::ExtendedShape(4, input_shape);
6029   const RuntimeShape ext_output_shape =
6030       RuntimeShape::ExtendedShape(4, output_shape);
6031   TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
6032   TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
6033 
6034   // Pad kernels are limited to max 4 dimensions. Copy inputs so we can pad them
6035   // to 4 dims (yes, we are "padding the padding").
6036   std::vector<int> left_padding_copy(4, 0);
6037   const int left_padding_extend = 4 - op_params.left_padding_count;
6038   for (int i = 0; i < op_params.left_padding_count; ++i) {
6039     left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
6040   }
6041   std::vector<int> right_padding_copy(4, 0);
6042   const int right_padding_extend = 4 - op_params.right_padding_count;
6043   for (int i = 0; i < op_params.right_padding_count; ++i) {
6044     right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
6045   }
6046   // The following padding restrictions are contractual requirements, and
6047   // embody what it means for a padding op to be "image-style".
6048   TFLITE_DCHECK_EQ(left_padding_copy[0], 0);
6049   TFLITE_DCHECK_EQ(left_padding_copy[3], 0);
6050   TFLITE_DCHECK_EQ(right_padding_copy[0], 0);
6051   TFLITE_DCHECK_EQ(right_padding_copy[3], 0);
6052 
6053   const int batch = MatchingDim(ext_input_shape, 0, ext_output_shape, 0);
6054   const int output_height = ext_output_shape.Dims(1);
6055   const int output_width = ext_output_shape.Dims(2);
6056   const int input_height = ext_input_shape.Dims(1);
6057   const int input_width = ext_input_shape.Dims(2);
6058   const int depth = MatchingDim(ext_input_shape, 3, ext_output_shape, 3);
6059 
6060   const int left_h_padding = left_padding_copy[1];
6061   const int left_w_padding = left_padding_copy[2];
6062   const int right_h_padding = right_padding_copy[1];
6063   const int right_w_padding = right_padding_copy[2];
6064 
6065   TFLITE_DCHECK_EQ(output_height,
6066                    input_height + left_h_padding + right_h_padding);
6067   TFLITE_DCHECK_EQ(output_width,
6068                    input_width + left_w_padding + right_w_padding);
6069 
6070   const T pad_value = *pad_value_ptr;
6071   const int top_block_size = left_h_padding * output_width * depth;
6072   const size_t num_top_block_bytes = top_block_size * sizeof(T);
6073   const int bottom_block_size = right_h_padding * output_width * depth;
6074   const size_t num_bottom_block_bytes = bottom_block_size * sizeof(T);
6075   const int left_blocks_size = left_w_padding * depth;
6076   const size_t num_left_block_bytes = left_blocks_size * sizeof(T);
6077   const int right_blocks_size = right_w_padding * depth;
6078   const size_t num_right_block_bytes = right_blocks_size * sizeof(T);
6079   const int inner_line_size = input_width * depth;
6080   const size_t num_inner_line_bytes = inner_line_size * sizeof(T);
6081 
6082   if (input_height == 0) {
6083     memset(output_data, pad_value,
6084            num_top_block_bytes + num_bottom_block_bytes);
6085   } else {
6086     for (int i = 0; i < batch; ++i) {
6087       // For each image in the batch, apply the top padding, then iterate
6088       // through rows, then apply the bottom padding.
6089       //
6090       // By unwinding one iteration, we can combine the first left-margin
6091       // padding with the top padding, and the last right-margin padding with
6092       // the bottom padding.
6093       memset(output_data, pad_value,
6094              num_top_block_bytes + num_left_block_bytes);
6095       output_data += top_block_size + left_blocks_size;
6096       memcpy(output_data, input_data, num_inner_line_bytes);
6097       input_data += inner_line_size;
6098       output_data += inner_line_size;
6099       // One iteration unwound.
6100       // Unwinding this loop affords the opportunity to reorder the loop work
6101       // and hence combine memset() calls.
6102       //
6103       // Before unwinding:
6104       // for (int j = 0; j < input_height; ++j) {
6105       //   // Pad on left, copy central data, pad on right.
6106       //   memset(output_data, pad_value, num_left_block_bytes);
6107       //   output_data += left_blocks_size;
6108       //   memcpy(output_data, input_data, num_inner_line_bytes);
6109       //   input_data += inner_line_size;
6110       //   output_data += inner_line_size;
6111       //   memset(output_data, pad_value, num_right_block_bytes);
6112       //   output_data += right_blocks_size;
6113       // }
6114       for (int j = 1; j < input_height; ++j) {
6115         memset(output_data, pad_value,
6116                num_right_block_bytes + num_left_block_bytes);
6117         output_data += right_blocks_size + left_blocks_size;
6118         memcpy(output_data, input_data, num_inner_line_bytes);
6119         input_data += inner_line_size;
6120         output_data += inner_line_size;
6121       }
6122       memset(output_data, pad_value,
6123              num_right_block_bytes + num_bottom_block_bytes);
6124       output_data += right_blocks_size + bottom_block_size;
6125     }
6126   }
6127 }
6128 
6129 template <typename T, typename P>
6130 inline void PadImageStyle(const tflite::PadParams& op_params,
6131                           const RuntimeShape& input_shape, const T* input_data,
6132                           const P* pad_value_ptr,
6133                           const RuntimeShape& output_shape, T* output_data) {
6134   TFLITE_ASSERT_FALSE;
6135 }
6136 
6137 template <typename P>
6138 inline void PadImageStyle(const tflite::PadParams& op_params,
6139                           const RuntimeShape& input_shape,
6140                           const uint8* input_data, const P* pad_value_ptr,
6141                           const RuntimeShape& output_shape,
6142                           uint8* output_data) {
6143   PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
6144                       output_shape, output_data);
6145 }
6146 
6147 template <typename P>
6148 inline void PadImageStyle(const tflite::PadParams& op_params,
6149                           const RuntimeShape& input_shape,
6150                           const float* input_data, const P* pad_value_ptr,
6151                           const RuntimeShape& output_shape,
6152                           float* output_data) {
6153   const float converted_pad_value = static_cast<float>(*pad_value_ptr);
6154   if (converted_pad_value == 0.0f) {
6155     PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
6156                         output_shape, output_data);
6157   } else {
6158     PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
6159             output_data);
6160   }
6161 }
6162 
6163 template <typename T>
6164 inline void Slice(const tflite::SliceParams& op_params,
6165                   const RuntimeShape& input_shape, const T* input_data,
6166                   const RuntimeShape& output_shape, T* output_data) {
6167   gemmlowp::ScopedProfilingLabel label("Slice");
6168   const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
6169   // TODO(dkalenichenko): This op only supports 4D tensors or smaller.
6170   TFLITE_DCHECK_LE(op_params.begin_count, 4);
6171   TFLITE_DCHECK_LE(op_params.size_count, 4);
6172   const int begin_count = op_params.begin_count;
6173   const int size_count = op_params.size_count;
6174   // We front-pad the begin and size vectors.
6175   const int start_b = 4 - begin_count > 0 ? 0 : op_params.begin[0];
6176   const int stop_b = (4 - size_count > 0 || op_params.size[0] == -1)
6177                          ? ext_shape.Dims(0) - start_b
6178                          : start_b + op_params.size[0];
6179   const int start_h = begin_count < 3 ? 0 : op_params.begin[begin_count - 3];
6180   const int stop_h = (size_count < 3 || op_params.size[size_count - 3] == -1)
6181                          ? ext_shape.Dims(1) - start_h
6182                          : start_h + op_params.size[size_count - 3];
6183   const int start_w = begin_count < 2 ? 0 : op_params.begin[begin_count - 2];
6184   const int stop_w = (size_count < 2 || op_params.size[size_count - 2] == -1)
6185                          ? ext_shape.Dims(2) - start_w
6186                          : start_w + op_params.size[size_count - 2];
6187   const int start_d = begin_count < 1 ? 0 : op_params.begin[begin_count - 1];
6188   const int stop_d = (size_count < 1 || op_params.size[size_count - 1] == -1)
6189                          ? ext_shape.Dims(3) - start_d
6190                          : start_d + op_params.size[size_count - 1];
6191 
6192   T* out_ptr = output_data;
6193   for (int in_b = start_b; in_b < stop_b; ++in_b) {
6194     for (int in_h = start_h; in_h < stop_h; ++in_h) {
6195       for (int in_w = start_w; in_w < stop_w; ++in_w) {
6196         const int len = stop_d - start_d;
6197         memcpy(out_ptr,
6198                input_data + Offset(ext_shape, in_b, in_h, in_w, start_d),
6199                len * sizeof(T));
6200         out_ptr += len;
6201       }
6202     }
6203   }
6204 }
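// Call sketch (illustrative; the shapes are assumptions): to copy rows 1..2
// and all columns/channels of a {1, 4, 5, 3} input, producing a {1, 2, 5, 3}
// output:
//
//   tflite::SliceParams op_params;
//   op_params.begin_count = 4;
//   op_params.size_count = 4;
//   op_params.begin[0] = 0; op_params.begin[1] = 1;
//   op_params.begin[2] = 0; op_params.begin[3] = 0;
//   op_params.size[0] = -1; op_params.size[1] = 2;  // -1 means "to the end".
//   op_params.size[2] = -1; op_params.size[3] = -1;
//   Slice(op_params, input_shape, input_data, output_shape, output_data);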
6205 
6206 template <typename T>
6207 void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
6208              const T* input2_data, const RuntimeShape& output_shape,
6209              T* output_data) {
6210   gemmlowp::ScopedProfilingLabel label("TensorFlowMinimum");
6211   auto input1_map = MapAsVector(input1_data, input1_shape);
6212   auto output_map = MapAsVector(output_data, output_shape);
6213   auto min_value = input2_data[0];
6214   output_map.array() = input1_map.array().min(min_value);
6215 }
6216 
6217 // Convenience version that allows, for example, generated-code calls to be
6218 // the same as other binary ops.
6219 template <typename T>
6220 inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
6221                     const RuntimeShape&, const T* input2_data,
6222                     const RuntimeShape& output_shape, T* output_data) {
6223   // Drop shape of second input: not needed.
6224   Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
6225 }
6226 
6227 template <typename T>
6228 void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
6229              const T* input2_data, const RuntimeShape& output_shape,
6230              T* output_data) {
6231   gemmlowp::ScopedProfilingLabel label("TensorFlowMaximum");
6232   auto input1_map = MapAsVector(input1_data, input1_shape);
6233   auto output_map = MapAsVector(output_data, output_shape);
6234   auto max_value = input2_data[0];
6235   output_map.array() = input1_map.array().max(max_value);
6236 }
6237 
6238 // Convenience version that allows, for example, generated-code calls to be
6239 // the same as other binary ops.
6240 template <typename T>
6241 inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
6242                     const RuntimeShape&, const T* input2_data,
6243                     const RuntimeShape& output_shape, T* output_data) {
6244   // Drop shape of second input: not needed.
6245   Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
6246 }
6247 
6248 template <typename T>
6249 void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
6250                      const RuntimeShape& input_shape, const T* input_data,
6251                      const RuntimeShape& filter_shape,
6252                      const RuntimeShape& output_shape, T* im2col_data) {
6253   gemmlowp::ScopedProfilingLabel label("TransposeIm2col");
6254   const int stride_width = params.stride_width;
6255   const int stride_height = params.stride_height;
6256   const int pad_width = params.padding_values.width;
6257   const int pad_height = params.padding_values.height;
6258   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
6259   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
6260   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
6261   TFLITE_DCHECK(im2col_data);
6262 
6263   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
6264   const int input_height = input_shape.Dims(1);
6265   const int input_width = input_shape.Dims(2);
6266   const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
6267   const int filter_height = filter_shape.Dims(1);
6268   const int filter_width = filter_shape.Dims(2);
6269   const int output_height = output_shape.Dims(1);
6270   const int output_width = output_shape.Dims(2);
6271   MatchingDim(output_shape, 3, filter_shape, 0);  // output_depth
6272 
6273   // Construct the MxN sized im2col matrix.
6274   // The rows M, are sub-ordered B x H x W
6275   const RuntimeShape row_shape({1, batches, output_height, output_width});
6276   // The columns, N, are sub-ordered Kh x Kw x Din
6277   const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
6278   // Use dimensions M and N to construct dims for indexing directly into im2col
6279   const RuntimeShape im2col_shape(
6280       {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
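  // For example (illustrative): with batches = 1, a 4 x 4 output, a 3 x 3
  // filter and input_depth = 8, M = 1 * 4 * 4 = 16 and N = 3 * 3 * 8 = 72,
  // i.e. im2col is a 16 x 72 matrix.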
6281 
6282   // Build the im2col matrix by looping through all the input pixels,
6283   // computing their influence on the output, rather than looping through all
6284   // the output pixels. We therefore must initialize the im2col array to zero.
6285   // This is potentially inefficient because we subsequently overwrite bytes
6286   // set here. However, in practice memset is very fast and its cost is negligible.
6287   memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T));

  // Loop through the output batches.
  for (int batch = 0; batch < batches; ++batch) {
    // Loop through input pixels one at a time.
    for (int in_y = 0; in_y < input_height; ++in_y) {
      for (int in_x = 0; in_x < input_width; ++in_x) {
        // Loop through the output pixels it will influence.
        const int out_x_origin = (in_x * stride_width) - pad_width;
        const int out_y_origin = (in_y * stride_height) - pad_height;
        for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
          const int out_y = out_y_origin + filter_y;
          // Is output pixel within height bounds?
          if ((out_y >= 0) && (out_y < output_height)) {
            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
              const int out_x = out_x_origin + filter_x;
              // Is output pixel within width bounds?
              if ((out_x >= 0) && (out_x < output_width)) {
                // Copy the input elements of this pixel.
                T const* src =
                    input_data + Offset(input_shape, batch, in_y, in_x, 0);
                int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
                int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
                T* dst = im2col_data +
                         Offset(im2col_shape, 0, 0, row_offset, col_offset);
                memcpy(dst, src, input_depth * sizeof(T));
              }
            }
          }
        }
      }
    }
  }
}

inline void TransposeConv(
    const ConvParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& filter_shape,
    const float* filter_data, const RuntimeShape& output_shape,
    float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
  gemmlowp::ScopedProfilingLabel label("TransposeConv");
  // The complexity of the reference implementation is input.flat_size() *
  // filter.flat_size() / in_channel.
  //
  // The complexity of the im2col->gemm implementation is batch *
  // output_height * output_width * (filter.flat_size() / out_channel)^2 *
  // out_channel.
  //
  // So if input.flat_size() * out_channel^2 is much smaller than
  // output.flat_size() * filter.flat_size() * in_channel, we should fall back
  // to the reference implementation. The check below applies a factor-of-4
  // margin in favor of the im2col->gemm path.
  //
  // TODO(b/122331966): optimize the intuitive implementation.
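  // Illustrative example (numbers chosen here, not part of the original
  // source): for a 1x4x4x8 input (flat size 128), a 16x2x2x8 filter (flat
  // size 512) and a 1x8x8x16 output (flat size 1024), the check compares
  // 128 * 16 * 16 * 4 = 131072 against 512 * 1024 * 8 = 4194304, so the
  // reference implementation is chosen.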
  const int out_channel = output_shape.Dims(3);
  const int in_channel = input_shape.Dims(3);
  if ((input_shape.FlatSize() * out_channel * out_channel * 4) <
      (filter_shape.FlatSize() * output_shape.FlatSize() * in_channel)) {
    reference_ops::TransposeConv(params, input_shape, input_data, filter_shape,
                                 filter_data, output_shape, output_data,
                                 im2col_shape, im2col_data);
    return;
  }
  // Note we could use transposed weights with forward conv for unstrided
  // cases. But we are already getting good performance with this code as-is.
  TFLITE_DCHECK(im2col_data);
  TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
                  output_shape, im2col_data);

  const auto im2col_matrix_map =
      MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
  const auto filter_matrix_map =
      MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
  auto output_matrix_map =
      MapAsMatrixWithLastDimAsRows(output_data, output_shape);

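  // With the mappings above (the helpers map the last dimension to matrix
  // rows and the first dimension to matrix columns, per their names),
  // filter_matrix_map is (Kh * Kw * input_depth) x out_channel,
  // im2col_matrix_map is (Kh * Kw * input_depth) x
  // (batch * output_height * output_width), and output_matrix_map is
  // out_channel x (batch * output_height * output_width), so the product
  // filter^T * im2col has exactly the output's shape.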
  Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
}

// Integer-only version of ResizeNearestNeighbor. Since scales are represented
// in fixed-point and thus approximated, |in_x| or |in_y| may differ from the
// reference version. Debug checks are in place to test if this occurs.
inline void ResizeNearestNeighbor(
    const tflite::ResizeNearestNeighborParams& op_params,
    const RuntimeShape& unextended_input_shape, const uint8* input_data,
    const RuntimeShape& output_size_shape, const int32* output_size_data,
    const RuntimeShape& unextended_output_shape, uint8* output_data) {
  // Align corners = true is not supported.
  TFLITE_DCHECK(!op_params.align_corners);
  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);

  const RuntimeShape input_shape =
      RuntimeShape::ExtendedShape(4, unextended_input_shape);
  const RuntimeShape output_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_shape);

  int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
  int32 input_height = input_shape.Dims(1);
  int32 input_width = input_shape.Dims(2);
  int32 depth = MatchingDim(input_shape, 3, output_shape, 3);

  // The TensorFlow version of this op allows resizing along the width and
  // height axes only.
  TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
  int32 output_height = output_size_data[0];
  int32 output_width = output_size_data[1];

  // Convert the scales to fixed-point with 16 fractional bits. We add 1 as an
  // error factor and to avoid zero scales: the plain integer division
  // truncates, and could in principle round a very small scale down to zero.
  int32 height_scale = (input_height << 16) / output_height + 1;
  int32 width_scale = (input_width << 16) / output_width + 1;
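  // Illustrative example (numbers chosen here, not part of the original
  // source): input_height = 3, output_height = 5 gives height_scale =
  // (3 << 16) / 5 + 1 = 39322, and (y * 39322) >> 16 for y = 0..4 yields
  // 0, 0, 1, 1, 2, which matches std::floor(y * 3.0f / 5.0f).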

  const int col_offset = input_shape.Dims(3);
  const int row_offset = input_shape.Dims(2) * col_offset;
  const int batch_offset = input_shape.Dims(1) * row_offset;
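  // The offsets above are the element strides of the NHWC input: one step in
  // x advances by the depth, one step in y by a full row, and one step in the
  // batch dimension by a full image.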

  const uint8* input_ptr = input_data;
  uint8* output_ptr = output_data;
  for (int b = 0; b < batches; ++b) {
    for (int y = 0; y < output_height; ++y) {
      int32 in_y = std::min((y * height_scale) >> 16, input_height - 1);
      // Check offset calculation is the same as the reference version. See
      // function comment for details. We check using a non-float version of:
      // TFLITE_DCHECK_EQ(in_y, std::floor(y * (static_cast<float>(input_height)
      //                                            / output_height)));
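      // (in_y == floor(y * input_height / output_height) holds exactly when
      // in_y * output_height <= y * input_height < (in_y + 1) * output_height,
      // which is what the two checks below verify.)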
      TFLITE_DCHECK_LT(y * input_height, output_height + in_y * output_height);
      TFLITE_DCHECK_GE(y * input_height, in_y * output_height);
      const uint8* y_input_ptr = input_ptr + in_y * row_offset;
      for (int x = 0; x < output_width; ++x) {
        int32 in_x = std::min((x * width_scale) >> 16, input_width - 1);
        // Check offset calculation is the same as the reference version. See
        // function comment for details. We check using a non-float version of:
        // TFLITE_DCHECK_EQ(in_x,
        //                  std::floor(x * (static_cast<float>(input_width)
        //                                      / output_width)));
        TFLITE_DCHECK_LT(x * input_width, output_width + in_x * output_width);
        TFLITE_DCHECK_GE(x * input_width, in_x * output_width);
        const uint8* x_input_ptr = y_input_ptr + in_x * col_offset;
        memcpy(output_ptr, x_input_ptr, depth);
        output_ptr += depth;
      }
    }
    input_ptr += batch_offset;
  }
}

}  // namespace optimized_ops
}  // namespace tflite

#if defined OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
#undef OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
#pragma GCC diagnostic pop
#endif

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_