/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_

#include <stdint.h>
#include <sys/types.h>

#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_ops {

// Unoptimized reference ops:
using reference_ops::ArgMax;
using reference_ops::ArgMinMax;
using reference_ops::Broadcast4DSlowGreater;
using reference_ops::Broadcast4DSlowGreaterEqual;
using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
using reference_ops::Broadcast4DSlowGreaterWithScaling;
using reference_ops::Broadcast4DSlowLess;
using reference_ops::Broadcast4DSlowLessEqual;
using reference_ops::Broadcast4DSlowLessEqualWithScaling;
using reference_ops::Broadcast4DSlowLessWithScaling;
using reference_ops::BroadcastAdd4DSlow;
using reference_ops::BroadcastGreater;
using reference_ops::BroadcastGreaterEqual;
using reference_ops::BroadcastLess;
using reference_ops::BroadcastLessEqual;
using reference_ops::BroadcastMul4DSlow;
using reference_ops::BroadcastSubSlow;
using reference_ops::Concatenation;
using reference_ops::ConcatenationWithScaling;
using reference_ops::DepthConcatenation;
using reference_ops::Div;
using reference_ops::FakeQuant;
using reference_ops::Gather;
using reference_ops::Greater;
using reference_ops::GreaterEqual;
using reference_ops::GreaterEqualWithScaling;
using reference_ops::GreaterWithScaling;
using reference_ops::Less;
using reference_ops::LessEqual;
using reference_ops::LessEqualWithScaling;
using reference_ops::LessWithScaling;
using reference_ops::Mean;
using reference_ops::RankOneSelect;
using reference_ops::Relu1;
using reference_ops::Relu6;
using reference_ops::ReluX;
using reference_ops::Select;
using reference_ops::SpaceToBatchND;
using reference_ops::Split;
using reference_ops::TensorFlowSplit;

static constexpr int kDepthwiseReverseShift = -1;
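// Note: legacy call sites pass output_shift as a positive right-shift amount,
// while DepthwiseParams::output_shift uses the newer positive-means-left
// convention. Multiplying by kDepthwiseReverseShift converts between the two;
// for example, a legacy right shift of 3 becomes op_params.output_shift = -3.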

template <typename Scalar, int N>
VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
  const int size = FlatSize(dims);
  return VectorMap<Scalar>(data, size, 1);
}

template <typename Scalar, int N>
MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
                                                const Dims<N>& dims) {
  const int rows = dims.sizes[0];
  int cols = 1;
  for (int d = 1; d < N; d++) {
    cols *= dims.sizes[d];
  }
  return MatrixMap<Scalar>(data, rows, cols);
}

template <typename Scalar, int N>
MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
                                               const Dims<N>& dims) {
  const int cols = dims.sizes[N - 1];
  int rows = 1;
  for (int d = 0; d < N - 1; d++) {
    rows *= dims.sizes[d];
  }
  return MatrixMap<Scalar>(data, rows, cols);
}

template <typename Scalar, int N>
ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data,
                                              const Dims<N>& dims) {
  const int rows = dims.sizes[0];
  int cols = 1;
  for (int d = 1; d < N; d++) {
    cols *= dims.sizes[d];
  }
  return ArrayMap<Scalar>(data, rows, cols);
}
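
// Example (illustrative only): for a Dims<4> with sizes {8, 4, 4, 2}
// (innermost dimension first, 256 elements total),
// MapAsMatrixWithFirstDimAsRows yields an 8x32 Eigen map over the buffer,
// while MapAsMatrixWithLastDimAsCols yields a 128x2 map over the same data.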

// TODO(b/62193649): this function is only needed as long
// as we have the --variable_batch hack.
template <typename Scalar, int N>
MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
                                                   const Dims<N>& dims,
                                                   int rows) {
  const int flatsize = FlatSize(dims);
  TFLITE_DCHECK((flatsize % rows) == 0);
  const int cols = flatsize / rows;
  return MatrixMap<Scalar>(data, rows, cols);
}

inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
  for (int i = 0; i < 4; i++) {
    if (dims1.sizes[i] != dims2.sizes[i]) {
      return false;
    }
  }
  return true;
}

inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                          const float* filter_data, const Dims<4>& filter_dims,
                          const float* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height,
                          int dilation_width_factor, int dilation_height_factor,
                          int pad_width, int pad_height, int depth_multiplier,
                          float output_activation_min,
                          float output_activation_max, float* output_data,
                          const Dims<4>& output_dims) {
  tflite::DepthwiseParams op_params;
  // Padding type is ignored, but still set.
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = pad_width;
  op_params.padding_values.height = pad_height;
  op_params.stride_width = stride_width;
  op_params.stride_height = stride_height;
  op_params.dilation_width_factor = dilation_width_factor;
  op_params.dilation_height_factor = dilation_height_factor;
  op_params.depth_multiplier = depth_multiplier;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;

  const RuntimeShape output_shape = DimsToShape(output_dims);
  const int output_height = output_shape.Dims(1);

  DepthwiseConvImpl(op_params, DimsToShape(input_dims), input_data,
                    DimsToShape(filter_dims), filter_data,
                    DimsToShape(bias_dims), bias_data, output_shape,
                    output_data, CpuFlags(), /*thread_start=*/0,
                    /*thread_end=*/output_height, /*thread_dim=*/1);
}

inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                          const float* filter_data, const Dims<4>& filter_dims,
                          const float* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height, int pad_width,
                          int pad_height, int depth_multiplier,
                          float output_activation_min,
                          float output_activation_max, float* output_data,
                          const Dims<4>& output_dims) {
  DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
                bias_dims, stride_width, stride_height, 1, 1, pad_width,
                pad_height, depth_multiplier, output_activation_min,
                output_activation_max, output_data, output_dims);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                   const float* filter_data, const Dims<4>& filter_dims,
                   const float* bias_data, const Dims<4>& bias_dims,
                   int stride_width, int stride_height, int pad_width,
                   int pad_height, int depth_multiplier, float* output_data,
                   const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
                bias_dims, stride_width, stride_height, pad_width, pad_height,
                depth_multiplier, output_activation_min, output_activation_max,
                output_data, output_dims);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                   const float* filter_data, const Dims<4>& filter_dims,
                   const float* bias_data, const Dims<4>& bias_dims, int stride,
                   int pad_width, int pad_height, int depth_multiplier,
                   float* output_data, const Dims<4>& output_dims) {
  DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
                    bias_dims, stride, stride, pad_width, pad_height,
                    depth_multiplier, output_data, output_dims);
}

template <DepthwiseConvOutputRounding kOutputRounding>
inline void LegacyDepthwiseConvWithRounding(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, int thread_start, int thread_end, int thread_dim) {
  ruy::profiler::ScopeLabel label("DepthwiseConv/8bit");
  const int depth_multiplier = params.depth_multiplier;
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  const int dilation_width_factor = params.dilation_width_factor;
  const int dilation_height_factor = params.dilation_height_factor;
  TFLITE_DCHECK_GE(dilation_width_factor, 1);
  TFLITE_DCHECK_GE(dilation_height_factor, 1);
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
  const int input_depth = input_shape.Dims(3);
  TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);

// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
// Jetson TX-2. This compiler does not support the offsetof() macro.
#if defined(__aarch64__) && !defined(GOOGLE_L4T)
  const int stride_width = params.stride_width;
  const int stride_height = params.stride_height;
  const int pad_width = params.padding_values.width;
  const int pad_height = params.padding_values.height;
  const int output_shift = params.output_shift;

  // Call kernel optimized for depthwise convolutions using 3x3 filters if
  // parameters are supported.
  if (depthwise_conv::Fast3x3FilterKernelSupported(
          input_shape, filter_shape, stride_width, stride_height,
          dilation_width_factor, dilation_height_factor, pad_width, pad_height,
          depth_multiplier, output_shape, output_shift)) {
    ruy::profiler::ScopeLabel specialized_label("DepthwiseConv/8bit/3x3");
    depthwise_conv::DepthwiseConv3x3Filter<kOutputRounding>(
        params, input_shape, input_data, filter_shape, filter_data, bias_shape,
        bias_data, output_shape, output_data, thread_start, thread_end,
        thread_dim);
    return;
  }
#endif

  ruy::profiler::ScopeLabel specialized_label("DepthwiseConv/8bit/General");
  depthwise_conv::DepthwiseConvGeneral(params, input_shape, input_data,
                                       filter_shape, filter_data, bias_shape,
                                       bias_data, output_shape, output_data,
                                       thread_start, thread_end, thread_dim);
}

inline void LegacyDepthwiseConvImpl(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, int thread_start, int thread_end, int thread_dim) {
  return LegacyDepthwiseConvWithRounding<
      DepthwiseConvOutputRounding::kAwayFromZero>(
      params, input_shape, input_data, filter_shape, filter_data, bias_shape,
      bias_data, output_shape, output_data, thread_start, thread_end,
      thread_dim);
}

inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                          int32 input_offset, const uint8* filter_data,
                          const Dims<4>& filter_dims, int32 filter_offset,
                          const int32* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height,
                          int dilation_width_factor, int dilation_height_factor,
                          int pad_width, int pad_height, int depth_multiplier,
                          int32 output_offset, int32 output_multiplier,
                          int output_shift, int32 output_activation_min,
                          int32 output_activation_max, uint8* output_data,
                          const Dims<4>& output_dims) {
  tflite::DepthwiseParams op_params;
  // Padding type is ignored, but still set.
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = pad_width;
  op_params.padding_values.height = pad_height;
  op_params.stride_width = stride_width;
  op_params.stride_height = stride_height;
  op_params.dilation_width_factor = dilation_width_factor;
  op_params.dilation_height_factor = dilation_height_factor;
  op_params.depth_multiplier = depth_multiplier;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  op_params.input_offset = input_offset;
  op_params.weights_offset = filter_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  // Legacy ops used mixed left and right shifts; the new convention is that a
  // positive output_shift means a left shift.
  op_params.output_shift = kDepthwiseReverseShift * output_shift;

  const RuntimeShape output_shape = DimsToShape(output_dims);
  const int output_height = output_shape.Dims(1);

  LegacyDepthwiseConvImpl(
      op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
      filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
      output_data, /*thread_start=*/0,
      /*thread_end=*/output_height, /*thread_dim=*/1);
}

inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                          int32 input_offset, const uint8* filter_data,
                          const Dims<4>& filter_dims, int32 filter_offset,
                          const int32* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height, int pad_width,
                          int pad_height, int depth_multiplier,
                          int32 output_offset, int32 output_multiplier,
                          int output_shift, int32 output_activation_min,
                          int32 output_activation_max, uint8* output_data,
                          const Dims<4>& output_dims) {
  DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
                filter_offset, bias_data, bias_dims, stride_width,
                stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
                output_offset, output_multiplier, output_shift,
                output_activation_min, output_activation_max, output_data,
                output_dims);
}

// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                   int32 input_offset, const uint8* filter_data,
                   const Dims<4>& filter_dims, int32 filter_offset,
                   const int32* bias_data, const Dims<4>& bias_dims,
                   int stride_width, int stride_height, int pad_width,
                   int pad_height, int depth_multiplier, int32 output_offset,
                   int32 output_multiplier, int output_shift,
                   int32 output_activation_min, int32 output_activation_max,
                   uint8* output_data, const Dims<4>& output_dims) {
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
                filter_offset, bias_data, bias_dims, stride_width,
                stride_height, pad_width, pad_height, depth_multiplier,
                output_offset, output_multiplier, output_shift,
                output_activation_min, output_activation_max, output_data,
                output_dims);
}

// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                   int32 input_offset, const uint8* filter_data,
                   const Dims<4>& filter_dims, int32 filter_offset,
                   const int32* bias_data, const Dims<4>& bias_dims, int stride,
                   int pad_width, int pad_height, int depth_multiplier,
                   int32 output_offset, int32 output_multiplier,
                   int output_shift, int32 output_activation_min,
                   int32 output_activation_max, uint8* output_data,
                   const Dims<4>& output_dims) {
  DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
                    filter_dims, filter_offset, bias_data, bias_dims, stride,
                    stride, pad_width, pad_height, depth_multiplier,
                    output_offset, output_multiplier, output_shift,
                    output_activation_min, output_activation_max, output_data,
                    output_dims);
}

template <typename T, typename TS>
struct LegacyDepthwiseConvWorkerTask : public gemmlowp::Task {
  LegacyDepthwiseConvWorkerTask(
      const DepthwiseParams& params, const RuntimeShape& input_shape,
      const T* input_data, const RuntimeShape& filter_shape,
      const T* filter_data, const RuntimeShape& bias_shape, const TS* bias_data,
      const RuntimeShape& output_shape, T* output_data, int thread_start,
      int thread_end, int thread_dim)
      : params_(params),
        input_shape_(input_shape),
        input_data_(input_data),
        filter_shape_(filter_shape),
        filter_data_(filter_data),
        bias_shape_(bias_shape),
        bias_data_(bias_data),
        output_shape_(output_shape),
        output_data_(output_data),
        thread_start_(thread_start),
        thread_end_(thread_end),
        thread_dim_(thread_dim) {}

  void Run() override {
    LegacyDepthwiseConvImpl(params_, input_shape_, input_data_, filter_shape_,
                            filter_data_, bias_shape_, bias_data_,
                            output_shape_, output_data_, thread_start_,
                            thread_end_, thread_dim_);
  }

 private:
  const DepthwiseParams& params_;
  const RuntimeShape& input_shape_;
  const T* input_data_;
  const RuntimeShape& filter_shape_;
  const T* filter_data_;
  const RuntimeShape& bias_shape_;
  const TS* bias_data_;
  const RuntimeShape& output_shape_;
  T* output_data_;
  int thread_start_;
  int thread_end_;
  int thread_dim_;
};

inline int HowManyConvThreads(const RuntimeShape& output_shape,
                              const RuntimeShape& filter_shape,
                              int thread_dim) {
  constexpr int kMinMulPerThread = 8;
  const int output_units = output_shape.Dims(thread_dim);
  const int filter_height = filter_shape.Dims(1);
  const int filter_width = filter_shape.Dims(2);
  const int num_mul_per_unit =
      FlatSizeSkipDim(output_shape, thread_dim) * filter_height * filter_width;
  const int min_units_per_thread = kMinMulPerThread / num_mul_per_unit + 1;
  int thread_count = output_units / min_units_per_thread;
  return thread_count;
}
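
// Worked example (illustrative only): for a [1, 8, 8, 64] output threaded
// over dim 1 and a 3x3 filter, num_mul_per_unit = (1 * 8 * 64) * 3 * 3 = 4608,
// so min_units_per_thread = 8 / 4608 + 1 = 1 and up to 8 threads (one per
// output row) are suggested, before the caller clamps to the context's limit.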

inline void DepthwiseConv(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, gemmlowp::GemmContext* gemmlowp_context = nullptr) {
  ruy::profiler::ScopeLabel label("DepthwiseConv");

  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  const int output_batches = output_shape.Dims(0);
  const int output_rows = output_shape.Dims(1);
  int thread_count_batch = HowManyConvThreads(output_shape, filter_shape, 0);
  int thread_count_row = HowManyConvThreads(output_shape, filter_shape, 1);
  int thread_dim, thread_count, thread_dim_size;
  if (thread_count_batch > thread_count_row) {
    thread_dim = 0;
    thread_dim_size = output_batches;
    thread_count = thread_count_batch;
  } else {
    thread_dim = 1;
    thread_dim_size = output_rows;
    thread_count = thread_count_row;
  }

  const int max_threads =
      gemmlowp_context ? gemmlowp_context->max_num_threads() : 1;
  thread_count = std::max(1, std::min(thread_count, max_threads));

  if (thread_count == 1) {
    LegacyDepthwiseConvImpl(params, input_shape, input_data, filter_shape,
                            filter_data, bias_shape, bias_data, output_shape,
                            output_data, /*thread_start=*/0,
                            /*thread_end=*/output_rows, /*thread_dim=*/1);
  } else {
    std::vector<gemmlowp::Task*> tasks(thread_count);
    int thread_start = 0;
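    // Partition the remaining units evenly across the remaining tasks; e.g.
    // (illustrative) 10 rows over 4 threads split as 2, 2, 3, 3.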
    for (int i = 0; i < thread_count; ++i) {
      int thread_end =
          thread_start + (thread_dim_size - thread_start) / (thread_count - i);
      tasks[i] = new LegacyDepthwiseConvWorkerTask<uint8, int32>(
          params, input_shape, input_data, filter_shape, filter_data,
          bias_shape, bias_data, output_shape, output_data, thread_start,
          thread_end, thread_dim);
      thread_start = thread_end;
    }
    gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
  }
}

template <typename T, typename TS>
struct LegacyPerChannelDepthwiseConvWorkerTask : public gemmlowp::Task {
  LegacyPerChannelDepthwiseConvWorkerTask(
      const DepthwiseParams& params, const int32* output_multiplier,
      const int32* output_shift, const RuntimeShape& input_shape,
      const T* input_data, const RuntimeShape& filter_shape,
      const T* filter_data, const RuntimeShape& bias_shape, const TS* bias_data,
      const RuntimeShape& output_shape, T* output_data, int thread_start,
      int thread_end, int thread_dim)
      : params_(params),
        output_multiplier_(output_multiplier),
        output_shift_(output_shift),
        input_shape_(input_shape),
        input_data_(input_data),
        filter_shape_(filter_shape),
        filter_data_(filter_data),
        bias_shape_(bias_shape),
        bias_data_(bias_data),
        output_shape_(output_shape),
        output_data_(output_data),
        thread_start_(thread_start),
        thread_end_(thread_end),
        thread_dim_(thread_dim) {}

  void Run() override {
    CpuBackendContext backend_context;
    optimized_integer_ops::DepthwiseConvImpl(
        params_, output_multiplier_, output_shift_, input_shape_, input_data_,
        filter_shape_, filter_data_, bias_shape_, bias_data_, output_shape_,
        output_data_, thread_start_, thread_end_, thread_dim_, backend_context);
  }

 private:
  const DepthwiseParams& params_;
  const int32* output_multiplier_;
  const int32* output_shift_;
  const RuntimeShape& input_shape_;
  const T* input_data_;
  const RuntimeShape& filter_shape_;
  const T* filter_data_;
  const RuntimeShape& bias_shape_;
  const TS* bias_data_;
  const RuntimeShape& output_shape_;
  T* output_data_;
  int thread_start_;
  int thread_end_;
  int thread_dim_;
};

inline void DepthwiseConvPerChannel(
    const DepthwiseParams& params, const int32* output_multiplier,
    const int32* output_shift, const RuntimeShape& input_shape,
    const int8* input_data, const RuntimeShape& filter_shape,
    const int8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
    gemmlowp::GemmContext* gemmlowp_context = nullptr) {
  ruy::profiler::ScopeLabel label("DepthwiseConvInt8");

  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  const int output_batches = output_shape.Dims(0);
  const int output_rows = output_shape.Dims(1);
  int thread_count_batch = HowManyConvThreads(output_shape, filter_shape, 0);
  int thread_count_row = HowManyConvThreads(output_shape, filter_shape, 1);
  int thread_dim, thread_count, thread_dim_size;
  if (thread_count_batch > thread_count_row) {
    thread_dim = 0;
    thread_dim_size = output_batches;
    thread_count = thread_count_batch;
  } else {
    thread_dim = 1;
    thread_dim_size = output_rows;
    thread_count = thread_count_row;
  }

  const int max_threads =
      gemmlowp_context ? gemmlowp_context->max_num_threads() : 1;
  thread_count = std::max(1, std::min(thread_count, max_threads));

  if (thread_count == 1) {
    CpuBackendContext backend_context;
    optimized_integer_ops::DepthwiseConvImpl(
        params, output_multiplier, output_shift, input_shape, input_data,
        filter_shape, filter_data, bias_shape, bias_data, output_shape,
        output_data, /*thread_start=*/0,
        /*thread_end=*/output_rows, /*thread_dim=*/1, backend_context);
  } else {
    std::vector<gemmlowp::Task*> tasks(thread_count);
    int thread_start = 0;
    for (int i = 0; i < thread_count; ++i) {
      int thread_end =
          thread_start + (thread_dim_size - thread_start) / (thread_count - i);
      tasks[i] = new LegacyPerChannelDepthwiseConvWorkerTask<int8, int32>(
          params, output_multiplier, output_shift, input_shape, input_data,
          filter_shape, filter_data, bias_shape, bias_data, output_shape,
          output_data, thread_start, thread_end, thread_dim);
      thread_start = thread_end;
    }
    gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
  }
}

inline void DepthwiseConv(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& filter_shape,
    const float* filter_data, const RuntimeShape& bias_shape,
    const float* bias_data, const RuntimeShape& output_shape,
    float* output_data) {
  DepthwiseConvImpl(params, input_shape, input_data, filter_shape, filter_data,
                    bias_shape, bias_data, output_shape, output_data,
                    CpuFlags(),
                    /*thread_start=*/0,
                    /*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1);
}

inline void AddBiasAndEvalActivationFunction(const float* bias_data,
                                             const Dims<4>& bias_dims,
                                             float* array_data,
                                             const Dims<4>& array_dims,
                                             float output_activation_min,
                                             float output_activation_max) {
  AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
                                   DimsToShape(bias_dims), bias_data,
                                   DimsToShape(array_dims), array_data);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void AddBiasAndEvalActivationFunction(const float* bias_data,
                                      const Dims<4>& bias_dims,
                                      float* array_data,
                                      const Dims<4>& array_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims,
                                   output_activation_min,
                                   output_activation_max);
}

template <typename Lhs, typename Rhs, typename Result>
void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
          Eigen::MatrixBase<Result>* result) {
  if (rhs.cols() == 1) {
    ruy::profiler::ScopeLabel label("GEMV");
    result->col(0).noalias() = lhs * rhs.col(0);
  } else {
    ruy::profiler::ScopeLabel label("GEMM");
    result->noalias() = lhs * rhs;
  }
}
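
// Example (illustrative only): a 16x8 weight matrix times an 8x1 input column
// takes the "GEMV" branch above; noalias() tells Eigen the result does not
// overlap the operands, so no temporary is materialized.
//   Eigen::MatrixXf w(16, 8), in(8, 1), out(16, 1);
//   Gemm(w, in, &out);  // rhs.cols() == 1, so the GEMV path runs.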

inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& weights_shape,
    const float* weights_data, const RuntimeShape& bias_shape,
    const float* optional_bias_data, const RuntimeShape& output_shape,
    float* output_data) {
  ruy::profiler::ScopeLabel label("FullyConnected");
  const float output_activation_min = params.float_activation_min;
  const float output_activation_max = params.float_activation_max;

  // TODO(b/62193649): this convoluted shape computation (determining
  // input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows)
  // is because the current --variable_batch hack consists in overwriting the
  // 3rd dimension with the runtime batch size, as we don't keep track for each
  // array of which dimension is the batch dimension in it.
  // When that is fixed, this should become:
  // const auto input_matrix_map =
  //     MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
  const int dims_count = weights_shape.DimensionsCount();
  const int input_rows = weights_shape.Dims(dims_count - 1);
  const auto input_matrix_map =
      MapAsMatrixWithGivenNumberOfRows(input_data, input_shape, input_rows);
  const auto filter_matrix_map =
      MapAsMatrixWithLastDimAsRows(weights_data, weights_shape);
  auto output_matrix_map =
      MapAsMatrixWithLastDimAsRows(output_data, output_shape);

  Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);

  if (optional_bias_data != nullptr) {
    AddBiasAndEvalActivationFunction(
        output_activation_min, output_activation_max, bias_shape,
        optional_bias_data, output_shape, output_data);
  } else {
    const int flat_size = output_shape.FlatSize();
    for (int i = 0; i < flat_size; ++i) {
      output_data[i] = ActivationFunctionWithMinMax(
          output_data[i], output_activation_min, output_activation_max);
    }
  }
}

inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
                           const float* weights_data,
                           const Dims<4>& weights_dims, const float* bias_data,
                           const Dims<4>& bias_dims,
                           float output_activation_min,
                           float output_activation_max, float* output_data,
                           const Dims<4>& output_dims) {
  tflite::FullyConnectedParams op_params;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;

  FullyConnected(op_params, DimsToShape(input_dims), input_data,
                 DimsToShape(weights_dims), weights_data,
                 DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
                 output_data);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void FullyConnected(const float* input_data, const Dims<4>& input_dims,
                    const float* weights_data, const Dims<4>& weights_dims,
                    const float* bias_data, const Dims<4>& bias_dims,
                    float* output_data, const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
                 bias_dims, output_activation_min, output_activation_max,
                 output_data, output_dims);
}

struct GemmlowpOutputPipeline {
  typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
      ColVectorMap;
  typedef std::tuple<gemmlowp::OutputStageBiasAddition<ColVectorMap>,
                     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent,
                     gemmlowp::OutputStageClamp,
                     gemmlowp::OutputStageSaturatingCastToUint8>
      Pipeline;
  static Pipeline MakeExp(const int32* bias_data, int output_rows,
                          int32 output_offset, int32 output_multiplier,
                          int output_left_shift, int32 output_activation_min,
                          int32 output_activation_max) {
    ColVectorMap bias_vector(bias_data, output_rows);
    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
    bias_addition_stage.bias_vector = bias_vector;
    gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
    quantize_down_stage.result_offset_after_shift = output_offset;
    quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
    quantize_down_stage.result_exponent = output_left_shift;
    gemmlowp::OutputStageClamp clamp_stage;
    clamp_stage.min = output_activation_min;
    clamp_stage.max = output_activation_max;
    gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
    return std::make_tuple(bias_addition_stage, quantize_down_stage,
                           clamp_stage, saturating_cast_stage);
  }
};
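
// Usage sketch (illustrative only; the matrix names here are hypothetical):
// the tuple built by MakeExp is passed as the output pipeline of a gemmlowp
// GEMM, which then applies bias addition, fixed-point rescaling, clamping, and
// the saturating cast, in that order:
//   const auto pipeline = GemmlowpOutputPipeline::MakeExp(
//       bias_data, output_rows, output_offset, output_multiplier,
//       output_left_shift, output_activation_min, output_activation_max);
//   gemmlowp::GemmWithOutputPipeline<
//       uint8, uint8, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
//       gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
//       filter_offset, input_offset, pipeline);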

struct GemmlowpOutputPipelineInt8 {
  typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
      ColVectorMap;
  typedef std::tuple<gemmlowp::OutputStageBiasAddition<ColVectorMap>,
                     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent,
                     gemmlowp::OutputStageClamp,
                     gemmlowp::OutputStageSaturatingCastToInt8>
      Pipeline;
  static Pipeline MakeExp(const int32* bias_data, int output_rows,
                          int32 output_offset, int32 output_multiplier,
                          int output_left_shift, int32 output_activation_min,
                          int32 output_activation_max) {
    ColVectorMap bias_vector(bias_data, output_rows);
    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
    bias_addition_stage.bias_vector = bias_vector;
    gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
    quantize_down_stage.result_offset_after_shift = output_offset;
    quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
    quantize_down_stage.result_exponent = output_left_shift;
    gemmlowp::OutputStageClamp clamp_stage;
    clamp_stage.min = output_activation_min;
    clamp_stage.max = output_activation_max;
    gemmlowp::OutputStageSaturatingCastToInt8 saturating_cast_stage;
    return std::make_tuple(bias_addition_stage, quantize_down_stage,
                           clamp_stage, saturating_cast_stage);
  }
};

#ifdef USE_NEON
inline void LegacyFullyConnectedAsGEMVWorkerImpl(
    const RuntimeShape& input_shape, const uint8* input_data,
    int32 input_offset, const RuntimeShape& filter_shape,
    const uint8* filter_data, int32 filter_offset,
    const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
    int32 output_multiplier, int output_shift, int32 output_activation_min,
    int32 output_activation_max, const RuntimeShape& output_shape,
    uint8* output_data, int row_start, int row_end) {
  ruy::profiler::ScopeLabel label("FullyConnectedAsGEMV/8bit");
  TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
  const int output_dim_count = output_shape.DimensionsCount();
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
  const int input_size = FlatSizeSkipDim(input_shape, 0);
  static constexpr int kPeel = 4;
  const bool shift_left = (output_shift > 0);
  for (int k = 0; k < input_size; k += 64) {
    optimized_ops_preload_l1_stream(input_data + k);
  }
  for (int k = 0; k < kPeel * input_size; k += 64) {
    optimized_ops_preload_l1_stream(filter_data + k);
  }

  TFLITE_DCHECK_GE(row_end - row_start, kPeel);

  for (int out = row_start; out < row_end; out += kPeel) {
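    // If the row range is not a multiple of kPeel, the final iteration is
    // clamped so its 4-row window stays in bounds; the overlapping rows are
    // simply recomputed with identical results.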
798     out = std::min(out, row_end - kPeel);
799     int32x4_t acc0 = vdupq_n_s32(0);
800     int32x4_t acc1 = acc0;
801     int32x4_t acc2 = acc0;
802     int32x4_t acc3 = acc0;
803     const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
804     const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset);
805     int in = 0;
806     for (; in <= input_size - 16; in += 16) {
807       const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
808       const uint8* filter_ptr = filter_data + in + out * input_size;
809       uint8x16_t filter_val_u8_0 = vld1q_u8(filter_ptr);
810       optimized_ops_preload_l1_stream(filter_ptr + 64);
811       filter_ptr += input_size;
812       uint8x16_t filter_val_u8_1 = vld1q_u8(filter_ptr);
813       optimized_ops_preload_l1_stream(filter_ptr + 64);
814       filter_ptr += input_size;
815       uint8x16_t filter_val_u8_2 = vld1q_u8(filter_ptr);
816       optimized_ops_preload_l1_stream(filter_ptr + 64);
817       filter_ptr += input_size;
818       uint8x16_t filter_val_u8_3 = vld1q_u8(filter_ptr);
819       optimized_ops_preload_l1_stream(filter_ptr + 64);
820       int16x8_t input_val_0, input_val_1;
821       uint8x8_t low = vget_low_u8(input_val_u8);
822       uint8x8_t high = vget_high_u8(input_val_u8);
823       input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low));
824       input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high));
825       input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
826       input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
827       low = vget_low_u8(filter_val_u8_0);
828       high = vget_high_u8(filter_val_u8_0);
829       int16x8_t filter_val_0_0 = vreinterpretq_s16_u16(vmovl_u8(low));
830       int16x8_t filter_val_0_1 = vreinterpretq_s16_u16(vmovl_u8(high));
831       filter_val_0_0 = vaddq_s16(filter_val_0_0, filter_offset_vec);
832       filter_val_0_1 = vaddq_s16(filter_val_0_1, filter_offset_vec);
833       low = vget_low_u8(filter_val_u8_1);
834       high = vget_high_u8(filter_val_u8_1);
835       int16x8_t filter_val_1_0 = vreinterpretq_s16_u16(vmovl_u8(low));
836       int16x8_t filter_val_1_1 = vreinterpretq_s16_u16(vmovl_u8(high));
837       filter_val_1_0 = vaddq_s16(filter_val_1_0, filter_offset_vec);
838       filter_val_1_1 = vaddq_s16(filter_val_1_1, filter_offset_vec);
839       low = vget_low_u8(filter_val_u8_2);
840       high = vget_high_u8(filter_val_u8_2);
841       int16x8_t filter_val_2_0 = vreinterpretq_s16_u16(vmovl_u8(low));
842       int16x8_t filter_val_2_1 = vreinterpretq_s16_u16(vmovl_u8(high));
843       filter_val_2_0 = vaddq_s16(filter_val_2_0, filter_offset_vec);
844       filter_val_2_1 = vaddq_s16(filter_val_2_1, filter_offset_vec);
845       low = vget_low_u8(filter_val_u8_3);
846       high = vget_high_u8(filter_val_u8_3);
847       int16x8_t filter_val_3_0 = vreinterpretq_s16_u16(vmovl_u8(low));
848       int16x8_t filter_val_3_1 = vreinterpretq_s16_u16(vmovl_u8(high));
849       filter_val_3_0 = vaddq_s16(filter_val_3_0, filter_offset_vec);
850       filter_val_3_1 = vaddq_s16(filter_val_3_1, filter_offset_vec);
851       acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_0),
852                        vget_low_s16(input_val_0));
853       acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_0),
854                        vget_low_s16(input_val_0));
855       acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_0),
856                        vget_low_s16(input_val_0));
857       acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_0),
858                        vget_low_s16(input_val_0));
859       acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_1),
860                        vget_low_s16(input_val_1));
861       acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_1),
862                        vget_low_s16(input_val_1));
863       acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_1),
864                        vget_low_s16(input_val_1));
865       acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_1),
866                        vget_low_s16(input_val_1));
867       acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_0),
868                        vget_high_s16(input_val_0));
869       acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_0),
870                        vget_high_s16(input_val_0));
871       acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_0),
872                        vget_high_s16(input_val_0));
873       acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_0),
874                        vget_high_s16(input_val_0));
875       acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_1),
876                        vget_high_s16(input_val_1));
877       acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_1),
878                        vget_high_s16(input_val_1));
879       acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_1),
880                        vget_high_s16(input_val_1));
881       acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_1),
882                        vget_high_s16(input_val_1));
883     }
884     for (; in <= input_size - 8; in += 8) {
885       const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
886       const uint8* filter_ptr = filter_data + in + out * input_size;
887       uint8x8_t filter_val_u8_0 = vld1_u8(filter_ptr);
888       filter_ptr += input_size;
889       uint8x8_t filter_val_u8_1 = vld1_u8(filter_ptr);
890       filter_ptr += input_size;
891       uint8x8_t filter_val_u8_2 = vld1_u8(filter_ptr);
892       filter_ptr += input_size;
893       uint8x8_t filter_val_u8_3 = vld1_u8(filter_ptr);
894       int16x8_t input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
895       input_val = vaddq_s16(input_val, input_offset_vec);
896       int16x8_t filter_val_0 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_0));
897       filter_val_0 = vaddq_s16(filter_val_0, filter_offset_vec);
898       int16x8_t filter_val_1 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_1));
899       filter_val_1 = vaddq_s16(filter_val_1, filter_offset_vec);
900       int16x8_t filter_val_2 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_2));
901       filter_val_2 = vaddq_s16(filter_val_2, filter_offset_vec);
902       int16x8_t filter_val_3 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_3));
903       filter_val_3 = vaddq_s16(filter_val_3, filter_offset_vec);
904       acc0 =
905           vmlal_s16(acc0, vget_low_s16(filter_val_0), vget_low_s16(input_val));
906       acc1 =
907           vmlal_s16(acc1, vget_low_s16(filter_val_1), vget_low_s16(input_val));
908       acc2 =
909           vmlal_s16(acc2, vget_low_s16(filter_val_2), vget_low_s16(input_val));
910       acc3 =
911           vmlal_s16(acc3, vget_low_s16(filter_val_3), vget_low_s16(input_val));
912       acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
913                        vget_high_s16(input_val));
914       acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
915                        vget_high_s16(input_val));
916       acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
917                        vget_high_s16(input_val));
918       acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
919                        vget_high_s16(input_val));
920     }
921     if (in < input_size) {
922       int32 buf[16];
923       vst1q_s32(buf + 0, acc0);
924       vst1q_s32(buf + 4, acc1);
925       vst1q_s32(buf + 8, acc2);
926       vst1q_s32(buf + 12, acc3);
927       for (; in < input_size; in++) {
928         int lane = (in + 8 - input_size) % 4;
929         const int32 input_val = input_data[in] + input_offset;
930         for (int k = 0; k < kPeel; k++) {
931           int32 filter_val =
932               filter_data[in + (out + k) * input_size] + filter_offset;
933           buf[lane + 4 * k] += filter_val * input_val;
934         }
935       }
936       acc0 = vld1q_s32(buf + 0);
937       acc1 = vld1q_s32(buf + 4);
938       acc2 = vld1q_s32(buf + 8);
939       acc3 = vld1q_s32(buf + 12);
940     }
941 
942     // Horizontally reduce accumulators
943     int32x2_t pairwise_reduced_acc_0 =
944         vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0));
945     int32x2_t pairwise_reduced_acc_1 =
946         vpadd_s32(vget_low_s32(acc1), vget_high_s32(acc1));
947     int32x2_t pairwise_reduced_acc_2 =
948         vpadd_s32(vget_low_s32(acc2), vget_high_s32(acc2));
949     int32x2_t pairwise_reduced_acc_3 =
950         vpadd_s32(vget_low_s32(acc3), vget_high_s32(acc3));
951     const int32x2_t reduced_lo =
952         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
953     const int32x2_t reduced_hi =
954         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
955     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
956     // Add bias values.
957     int32x4_t bias_vec = vld1q_s32(bias_data + out);
958     reduced = vaddq_s32(reduced, bias_vec);
959     if (shift_left) {
960       const int32 multiplier_power_of_two = 1 << output_shift;
961       reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
962       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
963     } else {
964       // Multiply by the fixed-point multiplier.
965       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
966       // Rounding-shift-right.
967       using gemmlowp::RoundingDivideByPOT;
968       reduced = RoundingDivideByPOT(reduced, -output_shift);
969     }
970     // Add the output offset.
971     const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
972     reduced = vaddq_s32(reduced, output_offset_vec);
973     // Narrow values down to 16 bit signed.
974     const int16x4_t res16 = vqmovn_s32(reduced);
975     // Narrow values down to 8 bit unsigned, saturating.
976     uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16));
977     // Apply the clamping from the activation function
978     res8 = vmax_u8(res8, vdup_n_u8(output_activation_min));
979     res8 = vmin_u8(res8, vdup_n_u8(output_activation_max));
980     // Store results to destination.
981     vst1_lane_u8(output_data + out + 0, res8, 0);
982     vst1_lane_u8(output_data + out + 1, res8, 1);
983     vst1_lane_u8(output_data + out + 2, res8, 2);
984     vst1_lane_u8(output_data + out + 3, res8, 3);
985   }
986 }

struct LegacyFullyConnectedAsGEMVWorkerTask : public gemmlowp::Task {
  LegacyFullyConnectedAsGEMVWorkerTask(
      const RuntimeShape& input_shape, const uint8* input_data,
      int32 input_offset, const RuntimeShape& filter_shape,
      const uint8* filter_data, int32 filter_offset,
      const RuntimeShape& bias_shape, const int32* bias_data,
      int32 output_offset, int32 output_multiplier, int output_shift,
      int32 output_activation_min, int32 output_activation_max,
      const RuntimeShape& output_shape, uint8* output_data, int row_start,
      int row_end)
      : input_shape_(input_shape),
        input_data_(input_data),
        input_offset_(input_offset),
        filter_shape_(filter_shape),
        filter_data_(filter_data),
        filter_offset_(filter_offset),
        bias_shape_(bias_shape),
        bias_data_(bias_data),
        output_offset_(output_offset),
        output_multiplier_(output_multiplier),
        output_shift_(output_shift),
        output_activation_min_(output_activation_min),
        output_activation_max_(output_activation_max),
        output_shape_(output_shape),
        output_data_(output_data),
        row_start_(row_start),
        row_end_(row_end) {}

  void Run() override {
    LegacyFullyConnectedAsGEMVWorkerImpl(
        input_shape_, input_data_, input_offset_, filter_shape_, filter_data_,
        filter_offset_, bias_shape_, bias_data_, output_offset_,
        output_multiplier_, output_shift_, output_activation_min_,
        output_activation_max_, output_shape_, output_data_, row_start_,
        row_end_);
  }

  const RuntimeShape& input_shape_;
  const uint8* input_data_;
  int32 input_offset_;
  const RuntimeShape& filter_shape_;
  const uint8* filter_data_;
  int32 filter_offset_;
  const RuntimeShape& bias_shape_;
  const int32* bias_data_;
  int32 output_offset_;
  int32 output_multiplier_;
  int output_shift_;
  int32 output_activation_min_;
  int32 output_activation_max_;
  const RuntimeShape& output_shape_;
  uint8* output_data_;
  int row_start_;
  int row_end_;
};

inline void FullyConnectedAsGEMV(
    const RuntimeShape& input_shape, const uint8* input_data,
    int32 input_offset, const RuntimeShape& filter_shape,
    const uint8* filter_data, int32 filter_offset,
    const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
    int32 output_multiplier, int output_shift, int32 output_activation_min,
    int32 output_activation_max, const RuntimeShape& output_shape,
    uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
  const int output_dim_count = output_shape.DimensionsCount();
  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
  const int output_rows = output_shape.Dims(output_dim_count - 1);
  const int input_size = FlatSizeSkipDim(input_shape, 0);
  static constexpr int kKernelRows = 4;
  const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
      gemmlowp_context->max_num_threads(), output_rows, batches, input_size);
  if (thread_count == 1) {
    // Single-thread case: do the computation on the current thread, don't
    // use a threadpool.
    LegacyFullyConnectedAsGEMVWorkerImpl(
        input_shape, input_data, input_offset, filter_shape, filter_data,
        filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
        output_shift, output_activation_min, output_activation_max,
        output_shape, output_data, 0, output_rows);
    return;
  }

  // Multi-threaded case: use the gemmlowp context's threadpool.
  TFLITE_DCHECK_GT(thread_count, 1);
  std::vector<gemmlowp::Task*> tasks(thread_count);
  const int kRowsPerWorker = gemmlowp::RoundUp<kKernelRows>(
      gemmlowp::CeilQuotient(output_rows, thread_count));
  int row_start = 0;
  for (int i = 0; i < thread_count; ++i) {
    int row_end = std::min(output_rows, row_start + kRowsPerWorker);
    tasks[i] = new LegacyFullyConnectedAsGEMVWorkerTask(
        input_shape, input_data, input_offset, filter_shape, filter_data,
        filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
        output_shift, output_activation_min, output_activation_max,
        output_shape, output_data, row_start, row_end);
    row_start = row_end;
  }
  TFLITE_DCHECK_EQ(row_start, output_rows);
  gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
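  // Worked example of the partitioning above: with output_rows = 1000 and
  // thread_count = 4, CeilQuotient(1000, 4) = 250 and RoundUp<4>(250) = 252,
  // so the workers get rows [0, 252), [252, 504), [504, 756) and
  // [756, 1000); the std::min clamp only ever trims the last range.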
}
#endif  // USE_NEON

inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
  ruy::profiler::ScopeLabel label("FullyConnected/8bit");
  const int32 input_offset = params.input_offset;
  const int32 filter_offset = params.weights_offset;
  const int32 output_offset = params.output_offset;
  const int32 output_multiplier = params.output_multiplier;
  const int output_shift = params.output_shift;
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
  // TODO(b/62193649): This really should be:
  //     const int batches = ArraySize(output_dims, 1);
  // but the current --variable_batch hack consists of overwriting the 3rd
  // dimension with the runtime batch size, as we don't keep track, for each
  // array, of which of its dimensions is the batch dimension.
  const int output_dim_count = output_shape.DimensionsCount();
  const int filter_dim_count = filter_shape.DimensionsCount();
  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
#ifdef USE_NEON
  if (batches == 1) {
    const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
                                        output_shape, output_dim_count - 1);
    if (output_size >= 4) {
      return FullyConnectedAsGEMV(
          input_shape, input_data, input_offset, filter_shape, filter_data,
          filter_offset, bias_shape, bias_data, output_offset,
          output_multiplier, output_shift, output_activation_min,
          output_activation_max, output_shape, output_data, gemmlowp_context);
    }
  }
#endif  // USE_NEON
  const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
  const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
  TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
  const int output_rows = output_shape.Dims(output_dim_count - 1);
  TFLITE_DCHECK_EQ(output_rows, filter_rows);
  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);

  gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
      filter_data, output_rows, filter_cols, filter_cols);
  gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
      input_data, filter_cols, batches, filter_cols);
  gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
      output_data, output_rows, batches, output_rows);
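  // These maps express the fully-connected op as a GEMM: filter_matrix
  // (output_rows x filter_cols, row-major) times input_matrix
  // (filter_cols x batches, column-major) yields output_matrix
  // (output_rows x batches); the trailing constructor argument is the
  // stride between consecutive rows (row-major) or columns (column-major).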
  const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
      bias_data, output_rows, output_offset, output_multiplier, output_shift,
      output_activation_min, output_activation_max);
  gemmlowp::GemmWithOutputPipeline<uint8, uint8,
                                   gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
      gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
      filter_offset, input_offset, output_pipeline);
}

#ifdef GEMMLOWP_NEON
// In the common case of batch size 1, a fully-connected node degenerates
// to a matrix*vector product. LSTM cells contain a fully-connected node;
// when quantized, this becomes a special kind of GEMV operation whose
// output is 16-bit quantized, and it therefore needs its own special path.
inline void GEMVForLstmCell(const RuntimeShape& input_shape,
                            const uint8* input_data,
                            const RuntimeShape& weights_shape,
                            const uint8* weights_data, uint8 weights_zero_point,
                            const RuntimeShape& bias_shape,
                            const int32* bias_data, int32 accum_multiplier,
                            int accum_shift, const RuntimeShape& output_shape,
                            int16* output_data) {
  ruy::profiler::ScopeLabel label("GEMVForLstmCell");
  TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
  const int output_dim_count = output_shape.DimensionsCount();
  const int weights_dim_count = weights_shape.DimensionsCount();
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
  const int input_size = FlatSizeSkipDim(input_shape, 0);
  const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
                                      output_shape, output_dim_count - 1);
  // This special fast path for quantized LSTM cells does not try to support
  // odd sizes that we haven't encountered in any LSTM cell; those would
  // require special code that would go untested until some LSTM cell
  // exercised it. We simply guard our assumptions about size evenness with
  // the following assertions.
  TFLITE_DCHECK(!(output_size % 4));
  TFLITE_DCHECK(!(input_size % 8));
  const int32* bias_ptr = bias_data;
  int16* output_ptr = output_data;
  for (int out = 0; out < output_size; out += 4) {
    int32x4_t acc_0 = vdupq_n_s32(0);
    int32x4_t acc_1 = vdupq_n_s32(0);
    int32x4_t acc_2 = vdupq_n_s32(0);
    int32x4_t acc_3 = vdupq_n_s32(0);
    const int16x8_t input_offset_vec = vdupq_n_s16(-128);
    const int16x8_t weights_offset_vec = vdupq_n_s16(-weights_zero_point);
    int in = 0;
    // Handle 16 levels of depth at a time.
    for (; in <= input_size - 16; in += 16) {
      const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
      const uint8* weights_ptr = weights_data + in + out * input_size;
      uint8x16_t weights_val_u8_0 = vld1q_u8(weights_ptr + 0 * input_size);
      uint8x16_t weights_val_u8_1 = vld1q_u8(weights_ptr + 1 * input_size);
      uint8x16_t weights_val_u8_2 = vld1q_u8(weights_ptr + 2 * input_size);
      uint8x16_t weights_val_u8_3 = vld1q_u8(weights_ptr + 3 * input_size);
      int16x8_t input_val_0, input_val_1;
      const uint8x8_t low = vget_low_u8(input_val_u8);
      const uint8x8_t high = vget_high_u8(input_val_u8);
      input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
      input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
      int16x8_t weights_val_0_0, weights_val_1_0, weights_val_2_0,
          weights_val_3_0;
      int16x8_t weights_val_0_1, weights_val_1_1, weights_val_2_1,
          weights_val_3_1;
      weights_val_0_0 = vaddq_s16(
          vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_0))),
          weights_offset_vec);
      weights_val_0_1 = vaddq_s16(
          vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_0))),
          weights_offset_vec);
      weights_val_1_0 = vaddq_s16(
          vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_1))),
          weights_offset_vec);
      weights_val_1_1 = vaddq_s16(
          vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_1))),
          weights_offset_vec);
      weights_val_2_0 = vaddq_s16(
          vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_2))),
          weights_offset_vec);
      weights_val_2_1 = vaddq_s16(
          vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_2))),
          weights_offset_vec);
      weights_val_3_0 = vaddq_s16(
          vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_3))),
          weights_offset_vec);
      weights_val_3_1 = vaddq_s16(
          vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_3))),
          weights_offset_vec);
      acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_0),
                        vget_low_s16(input_val_0));
      acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_0),
                        vget_low_s16(input_val_0));
      acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_0),
                        vget_low_s16(input_val_0));
      acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_0),
                        vget_low_s16(input_val_0));
      acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_0),
                        vget_high_s16(input_val_0));
      acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_0),
                        vget_high_s16(input_val_0));
      acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_0),
                        vget_high_s16(input_val_0));
      acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_0),
                        vget_high_s16(input_val_0));
      acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_1),
                        vget_low_s16(input_val_1));
      acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_1),
                        vget_low_s16(input_val_1));
      acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_1),
                        vget_low_s16(input_val_1));
      acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_1),
                        vget_low_s16(input_val_1));
      acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_1),
                        vget_high_s16(input_val_1));
      acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_1),
                        vget_high_s16(input_val_1));
      acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_1),
                        vget_high_s16(input_val_1));
      acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_1),
                        vget_high_s16(input_val_1));
    }
    // Handle 8 levels of depth at a time.
    for (; in < input_size; in += 8) {
      const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
      const uint8* weights_ptr = weights_data + in + out * input_size;
      uint8x8_t weights_val_u8_0 = vld1_u8(weights_ptr + 0 * input_size);
      uint8x8_t weights_val_u8_1 = vld1_u8(weights_ptr + 1 * input_size);
      uint8x8_t weights_val_u8_2 = vld1_u8(weights_ptr + 2 * input_size);
      uint8x8_t weights_val_u8_3 = vld1_u8(weights_ptr + 3 * input_size);
      int16x8_t input_val;
      input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
      input_val = vaddq_s16(input_val, input_offset_vec);
      int16x8_t weights_val_0, weights_val_1, weights_val_2, weights_val_3;
      weights_val_0 =
          vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_0)),
                    weights_offset_vec);
      weights_val_1 =
          vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_1)),
                    weights_offset_vec);
      weights_val_2 =
          vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_2)),
                    weights_offset_vec);
      weights_val_3 =
          vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_3)),
                    weights_offset_vec);
      acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0),
                        vget_low_s16(input_val));
      acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1),
                        vget_low_s16(input_val));
      acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2),
                        vget_low_s16(input_val));
      acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3),
                        vget_low_s16(input_val));
      acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0),
                        vget_high_s16(input_val));
      acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1),
                        vget_high_s16(input_val));
      acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2),
                        vget_high_s16(input_val));
      acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3),
                        vget_high_s16(input_val));
    }
    // Horizontally reduce accumulators.
    int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
        pairwise_reduced_acc_2, pairwise_reduced_acc_3;
    pairwise_reduced_acc_0 =
        vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0));
    pairwise_reduced_acc_1 =
        vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1));
    pairwise_reduced_acc_2 =
        vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2));
    pairwise_reduced_acc_3 =
        vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3));
    const int32x2_t reduced_lo =
        vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
    const int32x2_t reduced_hi =
        vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
    int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
    // Add bias values.
    int32x4_t bias_vec = vld1q_s32(bias_ptr);
    bias_ptr += 4;
    reduced = vaddq_s32(reduced, bias_vec);
    int left_shift = accum_shift > 0 ? accum_shift : 0;
    int right_shift = accum_shift > 0 ? 0 : -accum_shift;
    reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
    // Multiply by the fixed-point multiplier.
    reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
    // Rounding-shift-right.
    using gemmlowp::RoundingDivideByPOT;
    reduced = RoundingDivideByPOT(reduced, right_shift);
    // Narrow values down to 16 bit signed.
    const int16x4_t res16 = vqmovn_s32(reduced);
    vst1_s16(output_ptr, res16);
    output_ptr += 4;
  }
}
#endif
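
// A minimal scalar sketch of what GEMVForLstmCell computes, under the same
// quantization scheme (uint8 operands with an input zero point of 128, int32
// accumulation, int16 output). This illustrative helper is not used by the
// ops in this file; it exists only to document the arithmetic.
inline void GEMVForLstmCellScalarSketch(
    const uint8* input_data, int input_size, const uint8* weights_data,
    uint8 weights_zero_point, const int32* bias_data, int32 accum_multiplier,
    int accum_shift, int output_size, int16* output_data) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  for (int out = 0; out < output_size; ++out) {
    int32 acc = bias_data[out];
    for (int in = 0; in < input_size; ++in) {
      const int32 input_val = input_data[in] - 128;
      const int32 weights_val =
          weights_data[out * input_size + in] - weights_zero_point;
      acc += weights_val * input_val;
    }
    // Same requantization as the NEON path: an optional left shift, a
    // fixed-point multiply, then a rounding shift-right.
    const int left_shift = accum_shift > 0 ? accum_shift : 0;
    const int right_shift = accum_shift > 0 ? 0 : -accum_shift;
    acc = SaturatingRoundingDoublingHighMul(acc * (1 << left_shift),
                                            accum_multiplier);
    acc = RoundingDivideByPOT(acc, right_shift);
    // Saturating narrow to int16, as vqmovn_s32 does.
    acc = std::min<int32>(32767, std::max<int32>(-32768, acc));
    output_data[out] = static_cast<int16>(acc);
  }
}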

#ifdef GEMMLOWP_NEON
inline void GEMVForLstmCellWithSymmetricRange(
    const RuntimeShape& input_shape, const uint8* input_data,
    const RuntimeShape& weights_shape, const uint8* weights_data,
    const RuntimeShape& bias_shape, const int32* bias_data,
    int32 accum_multiplier, int accum_shift, const RuntimeShape& output_shape,
    int16* output_data) {
  ruy::profiler::ScopeLabel label("GEMVForLstmCellWithSymmetricRange");
  TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
  const int output_dim_count = output_shape.DimensionsCount();
  const int weights_dim_count = weights_shape.DimensionsCount();
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
  const int input_size = FlatSizeSkipDim(input_shape, 0);
  const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
                                      output_shape, output_dim_count - 1);
  // This special fast path for quantized LSTM cells does not try to support
  // odd sizes that we haven't encountered in any LSTM cell; those would
  // require special code that would go untested until some LSTM cell
  // exercised it. We simply guard our assumptions about size evenness with
  // the following assertions.
  TFLITE_DCHECK(!(output_size % 4));
  TFLITE_DCHECK(!(input_size % 64));
  const int32* bias_ptr = bias_data;
  int16* output_ptr = output_data;
  const uint8x16_t signbit = vdupq_n_u8(0x80);
  for (int in = 0; in < input_size; in += 32) {
    optimized_ops_preload_l1_keep(input_data + in);
  }
  const int left_shift = accum_shift > 0 ? accum_shift : 0;
  const int right_shift = accum_shift > 0 ? 0 : -accum_shift;
  for (int out = 0; out < output_size; out += 4) {
    // Load the bias values.
    int32x4_t bias_vec = vld1q_s32(bias_ptr);
    bias_ptr += 4;

    // Clear accumulators. We use 2 accumulator registers per row,
    // for 4 rows. row_accumRN is the N-th accumulator for row R.
    int32x4_t row_accum00 = vdupq_n_s32(0);
    int32x4_t row_accum01 = vdupq_n_s32(0);
    int32x4_t row_accum10 = vdupq_n_s32(0);
    int32x4_t row_accum11 = vdupq_n_s32(0);
    int32x4_t row_accum20 = vdupq_n_s32(0);
    int32x4_t row_accum21 = vdupq_n_s32(0);
    int32x4_t row_accum30 = vdupq_n_s32(0);
    int32x4_t row_accum31 = vdupq_n_s32(0);

    // kReadAhead parametrizes how far ahead we prefetch weights into L1 cache.
    const int kReadAhead = 512;
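    // With kReadAhead = 512 bytes and 64-byte cache lines, that is 8 cache
    // lines ahead on each of the 4 rows being processed.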
    // Prefetch the first weights values.
    for (int k = 0; k < kReadAhead; k += 64) {
      optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
                                      k);
      optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
                                      k);
      optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
                                      k);
      optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
                                      k);
    }
    // Loop along the rows, handling 64 bytes per iteration because that's
    // the cache line size on most current ARM-architecture CPUs.
    for (int in = 0; in < input_size; in += 64) {
      // Prefetch some future weights values.
      optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
                                      in + kReadAhead);
      optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
                                      in + kReadAhead);
      optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
                                      in + kReadAhead);
      optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
                                      in + kReadAhead);

      // We will use 2 local 16-bit accumulators per row, for 2 rows.
      // See below (*) for the rationale of processing only 2 rows at a time.
      // local_accumRN is the N-th local accumulator for row R.
      int16x8_t local_accum00;
      int16x8_t local_accum01;
      int16x8_t local_accum10;
      int16x8_t local_accum11;

      // Load 64 bytes of input activation values. Convert to signed int8
      // by flipping the sign bit (i.e. subtracting 128, the required
      // zero_point value).
      int8x16_t input0 = vreinterpretq_s8_u8(
          veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 0)));
      int8x16_t input1 = vreinterpretq_s8_u8(
          veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 1)));
      int8x16_t input2 = vreinterpretq_s8_u8(
          veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 2)));
      int8x16_t input3 = vreinterpretq_s8_u8(
          veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 3)));

      // Beginning of the core accumulation. Notice how while we have 4
      // rows to process, this code is taking care of only 2 rows at a time,
      // thus being divided into two parts looking similar ("Rows 0 and 1" and
      // "Rows 2 and 3").
      //
      // (*) The rationale for handling only 2 rows at a time is to avoid
      // cache aliasing issues on 4-way set-associative L1-cache CPUs, such
      // as Cortex-A53. With sufficiently large, power-of-two matrix dimensions,
      // we may find ourselves in a situation where rows alias each other in
      // the L1 cache, and moreover may also mutually alias with the input
      // activations. If we try to load 4 rows at a time, together with the
      // input activations, that may be 5 mutually-aliasing vectors, resulting
      // in constant mutual eviction from L1 cache. Handling 2 rows at a time
      // here largely mitigates these issues, and seems at least to be very
      // effective on Cortex-A53:
      //                          Before       After
      // big (Cortex-A73)         2.85 ms      2.85 ms
      // little (Cortex-A53)      11.0 ms      5.16 ms

      // Rows 0 and 1:
      // Load 64 bytes of weights values from each row. Convert to signed int8
      // by flipping the sign bit (i.e. subtracting 128, the required
      // zero_point value).
      int8x16_t weights00 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 0)));
      int8x16_t weights01 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 1)));
      int8x16_t weights02 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 2)));
      int8x16_t weights03 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 3)));
      int8x16_t weights10 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 0)));
      int8x16_t weights11 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 1)));
      int8x16_t weights12 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 2)));
      int8x16_t weights13 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 3)));
      // Multiply-accumulate into local 16-bit accumulators.
      // We can accumulate two products without overflow because weights are
      // required to never be -128, so each product is at most 127 * 128 in
      // absolute value, and the sum of two such products fits in int16.
      local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
      local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
      local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
      local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
      local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
                               vget_high_s8(input0));
      local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
                               vget_high_s8(input1));
      local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
                               vget_high_s8(input0));
      local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
                               vget_high_s8(input1));
      // Pairwise add and accumulate into 32-bit accumulators.
      row_accum00 = vpadalq_s16(row_accum00, local_accum00);
      row_accum01 = vpadalq_s16(row_accum01, local_accum01);
      row_accum10 = vpadalq_s16(row_accum10, local_accum10);
      row_accum11 = vpadalq_s16(row_accum11, local_accum11);
      // Multiply-accumulate into local 16-bit accumulators.
      // We can accumulate two products without overflow because weights are
      // required to never be -128, so each product is at most 127 * 128 in
      // absolute value, and the sum of two such products fits in int16.
      local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
      local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
      local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
      local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
      local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
                               vget_high_s8(input2));
      local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
                               vget_high_s8(input3));
      local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
                               vget_high_s8(input2));
      local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
                               vget_high_s8(input3));
      // Pairwise add and accumulate into 32-bit accumulators.
      row_accum00 = vpadalq_s16(row_accum00, local_accum00);
      row_accum01 = vpadalq_s16(row_accum01, local_accum01);
      row_accum10 = vpadalq_s16(row_accum10, local_accum10);
      row_accum11 = vpadalq_s16(row_accum11, local_accum11);

      // Rows 2 and 3:
      // Load 64 bytes of weights values from each row. Convert to signed int8
      // by flipping the sign bit (i.e. subtracting 128, the required
      // zero_point value).
      weights00 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 0)));
      weights01 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 1)));
      weights02 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 2)));
      weights03 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 3)));
      weights10 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 0)));
      weights11 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 1)));
      weights12 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 2)));
      weights13 = vreinterpretq_s8_u8(veorq_u8(
          signbit,
          vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 3)));
      // Multiply-accumulate into local 16-bit accumulators.
      // We can accumulate two products without overflow because weights are
      // required to never be -128, so each product is at most 127 * 128 in
      // absolute value, and the sum of two such products fits in int16.
      local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
      local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
      local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
      local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
      local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
                               vget_high_s8(input0));
      local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
                               vget_high_s8(input1));
      local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
                               vget_high_s8(input0));
      local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
                               vget_high_s8(input1));
      // Pairwise add and accumulate into 32-bit accumulators.
      row_accum20 = vpadalq_s16(row_accum20, local_accum00);
      row_accum21 = vpadalq_s16(row_accum21, local_accum01);
      row_accum30 = vpadalq_s16(row_accum30, local_accum10);
      row_accum31 = vpadalq_s16(row_accum31, local_accum11);
      // Multiply-accumulate into local 16-bit accumulators.
      // We can accumulate two products without overflow because weights are
      // required to never be -128, so each product is at most 127 * 128 in
      // absolute value, and the sum of two such products fits in int16.
      local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
      local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
      local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
      local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
      local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
                               vget_high_s8(input2));
      local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
                               vget_high_s8(input3));
      local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
                               vget_high_s8(input2));
      local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
                               vget_high_s8(input3));
      // Pairwise add and accumulate into 32-bit accumulators.
      row_accum20 = vpadalq_s16(row_accum20, local_accum00);
      row_accum21 = vpadalq_s16(row_accum21, local_accum01);
      row_accum30 = vpadalq_s16(row_accum30, local_accum10);
      row_accum31 = vpadalq_s16(row_accum31, local_accum11);
    }

    row_accum00 = vaddq_s32(row_accum00, row_accum01);
    row_accum10 = vaddq_s32(row_accum10, row_accum11);
    row_accum20 = vaddq_s32(row_accum20, row_accum21);
    row_accum30 = vaddq_s32(row_accum30, row_accum31);
    // Horizontally reduce accumulators.
    int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
        pairwise_reduced_acc_2, pairwise_reduced_acc_3;
    pairwise_reduced_acc_0 =
        vpadd_s32(vget_low_s32(row_accum00), vget_high_s32(row_accum00));
    pairwise_reduced_acc_1 =
        vpadd_s32(vget_low_s32(row_accum10), vget_high_s32(row_accum10));
    pairwise_reduced_acc_2 =
        vpadd_s32(vget_low_s32(row_accum20), vget_high_s32(row_accum20));
    pairwise_reduced_acc_3 =
        vpadd_s32(vget_low_s32(row_accum30), vget_high_s32(row_accum30));
    const int32x2_t reduced_lo =
        vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
    const int32x2_t reduced_hi =
        vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
    int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
    // Add bias values.
    reduced = vaddq_s32(reduced, bias_vec);
    reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
    // Multiply by the fixed-point multiplier.
    reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
    // Rounding-shift-right.
    using gemmlowp::RoundingDivideByPOT;
    reduced = RoundingDivideByPOT(reduced, right_shift);
    // Narrow values down to 16 bit signed.
    const int16x4_t res16 = vqmovn_s32(reduced);
    vst1_s16(output_ptr, res16);
    output_ptr += 4;
  }
}
#endif

inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data_int32, const RuntimeShape& output_shape,
    int16* output_data, gemmlowp::GemmContext* gemmlowp_context) {
  ruy::profiler::ScopeLabel label("FullyConnected/Uint8Int16");
  const int32 input_offset = params.input_offset;
  const int32 filter_offset = params.weights_offset;
  const int32 output_offset = params.output_offset;
  const int32 output_multiplier = params.output_multiplier;
  const int output_shift = params.output_shift;
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  // This is a copy of the reference implementation. We do not currently have a
  // properly optimized version.
  (void)gemmlowp_context;  // only used in properly optimized code.
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  TFLITE_DCHECK_EQ(output_offset, 0);
  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);

  // TODO(b/62193649): This really should be:
  //     const int batches = ArraySize(output_dims, 1);
  // but the current --variable_batch hack consists of overwriting the 3rd
  // dimension with the runtime batch size, as we don't keep track, for each
  // array, of which of its dimensions is the batch dimension.
  const int output_dim_count = output_shape.DimensionsCount();
  const int filter_dim_count = filter_shape.DimensionsCount();
  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
  const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
                                       output_shape, output_dim_count - 1);
  const int accum_depth = filter_shape.Dims(filter_dim_count - 1);

  // Implementation of the fully-connected node suited to the inside of an
  // LSTM cell. The operands are 8-bit integers, the accumulators are
  // internally 32-bit integers, and the output is 16-bit fixed-point with 3
  // integer bits, so the output range is [-2^3, 2^3] == [-8, 8]. The
  // rationale for that is explained in the function comment above.
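  // For example, in this Q3.12 format the raw int16 value 4096 represents
  // 4096 * 2^-12 = 1.0, and the saturation bound 32767 represents just
  // under 8.0.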
#ifdef GEMMLOWP_NEON
  if (batches == 1 && input_offset == -128 && output_activation_min == -32768 &&
      output_activation_max == 32767) {
    if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
      GEMVForLstmCellWithSymmetricRange(
          input_shape, input_data, filter_shape, filter_data, bias_shape,
          bias_data_int32, output_multiplier, output_shift, output_shape,
          output_data);
      return;
    }
    if (!(output_depth % 4) && !(accum_depth % 8)) {
      GEMVForLstmCell(input_shape, input_data, filter_shape, filter_data,
                      filter_offset, bias_shape, bias_data_int32,
                      output_multiplier, output_shift, output_shape,
                      output_data);
      return;
    }
  }
#endif
  gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> weights_matrix(
      filter_data, output_depth, accum_depth);
  gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
      input_data, accum_depth, batches);
  gemmlowp::MatrixMap<int16, gemmlowp::MapOrder::ColMajor> output_matrix(
      output_data, output_depth, batches);
  typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
      ColVectorMap;
  ColVectorMap bias_vector(bias_data_int32, output_depth);
  gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
  bias_addition_stage.bias_vector = bias_vector;
  gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
  scale_stage.result_offset_after_shift = 0;
  scale_stage.result_fixedpoint_multiplier = output_multiplier;
  // Note that this shift is negated wrt ordinary FC.
  scale_stage.result_exponent = output_shift;
  gemmlowp::OutputStageClamp clamp_stage;
  clamp_stage.min = output_activation_min;
  clamp_stage.max = output_activation_max;
  gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
  auto output_pipeline =
      std::make_tuple(bias_addition_stage, scale_stage, clamp_stage,
                      saturating_cast_int16_stage);
  gemmlowp::GemmWithOutputPipeline<uint8, int16,
                                   gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
      gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
      filter_offset, input_offset, output_pipeline);
}

inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
                           int32 input_offset, const uint8* filter_data,
                           const Dims<4>& filter_dims, int32 filter_offset,
                           const int32* bias_data, const Dims<4>& bias_dims,
                           int32 output_offset, int32 output_multiplier,
                           int output_shift, int32 output_activation_min,
                           int32 output_activation_max, uint8* output_data,
                           const Dims<4>& output_dims,
                           gemmlowp::GemmContext* gemmlowp_context) {
  tflite::FullyConnectedParams op_params;
  op_params.input_offset = input_offset;
  op_params.weights_offset = filter_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
  op_params.output_shift = kReverseShift * output_shift;
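  // (kReverseShift is -1, so a legacy right-shift amount such as 3 becomes
  // -3 here, matching the new positive-means-left convention.)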
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;

  FullyConnected(op_params, DimsToShape(input_dims), input_data,
                 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
                 bias_data, DimsToShape(output_dims), output_data,
                 gemmlowp_context);
}

inline void FullyConnected(
    const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
    const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
    const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
    int32 output_multiplier, int output_shift, int32 output_activation_min,
    int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
    gemmlowp::GemmContext* gemmlowp_context) {
  tflite::FullyConnectedParams op_params;
  op_params.input_offset = input_offset;
  op_params.weights_offset = filter_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
  op_params.output_shift = kReverseShift * output_shift;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;

  FullyConnected(op_params, DimsToShape(input_dims), input_data,
                 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
                 bias_data_int32, DimsToShape(output_dims), output_data,
                 gemmlowp_context);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
                    int32 input_offset, const uint8* filter_data,
                    const Dims<4>& filter_dims, int32 filter_offset,
                    const int32* bias_data, const Dims<4>& bias_dims,
                    int32 output_offset, int32 output_multiplier,
                    int output_shift, int32 output_activation_min,
                    int32 output_activation_max, uint8* output_data,
                    const Dims<4>& output_dims,
                    gemmlowp::GemmContext* gemmlowp_context) {
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
                 filter_offset, bias_data, bias_dims, output_offset,
                 output_multiplier, output_shift, output_activation_min,
                 output_activation_max, output_data, output_dims,
                 gemmlowp_context);
}

#ifdef USE_NEON
inline void LegacyInt8FullyConnectedAsGEMVWorkerImpl(
    const RuntimeShape& input_shape, const int8_t* input_data,
    int32 input_offset, const RuntimeShape& filter_shape,
    const int8_t* filter_data, int32 filter_offset,
    const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
    int32 output_multiplier, int output_shift, int32 output_activation_min,
    int32 output_activation_max, const RuntimeShape& output_shape,
    int8_t* output_data, int row_start, int row_end) {
  ruy::profiler::ScopeLabel label("FullyConnectedAsGEMVInt8/8bit");
  TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
  const int output_dim_count = output_shape.DimensionsCount();
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
  const int input_size = FlatSizeSkipDim(input_shape, 0);
  static constexpr int kPeel = 4;
  const bool shift_left = (output_shift > 0);
  TFLITE_DCHECK_GE(row_end - row_start, kPeel);

  for (int out = row_start; out < row_end; out += kPeel) {
    out = std::min(out, row_end - kPeel);
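    // Note: on the final iteration this clamp may move `out` back so that
    // the last kPeel-row block overlaps rows already computed; recomputing a
    // few rows is cheaper than a ragged scalar tail, and the DCHECK above
    // guarantees that row_end - kPeel >= row_start.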
1814     int32x4_t acc0 = vdupq_n_s32(0);
1815     int32x4_t acc1 = acc0;
1816     int32x4_t acc2 = acc0;
1817     int32x4_t acc3 = acc0;
1818     const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
1819     const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset);
1820     int in = 0;
1821     for (; in <= input_size - 16; in += 16) {
1822       const int8x16_t input_val_s8 = vld1q_s8(input_data + in);
1823       const int8_t* filter_ptr = filter_data + in + out * input_size;
1824       int8x16_t filter_val_s8_0 = vld1q_s8(filter_ptr);
1825       filter_ptr += input_size;
1826       int8x16_t filter_val_s8_1 = vld1q_s8(filter_ptr);
1827       filter_ptr += input_size;
1828       int8x16_t filter_val_s8_2 = vld1q_s8(filter_ptr);
1829       filter_ptr += input_size;
1830       int8x16_t filter_val_s8_3 = vld1q_s8(filter_ptr);
1831       int16x8_t input_val_0, input_val_1;
1832       int8x8_t low = vget_low_s8(input_val_s8);
1833       int8x8_t high = vget_high_s8(input_val_s8);
1834       input_val_0 = vmovl_s8(low);
1835       input_val_1 = vmovl_s8(high);
1836       input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
1837       input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
1838       low = vget_low_s8(filter_val_s8_0);
1839       high = vget_high_s8(filter_val_s8_0);
1840       int16x8_t filter_val_0_0 = vmovl_s8(low);
1841       int16x8_t filter_val_0_1 = vmovl_s8(high);
1842       filter_val_0_0 = vaddq_s16(filter_val_0_0, filter_offset_vec);
1843       filter_val_0_1 = vaddq_s16(filter_val_0_1, filter_offset_vec);
1844       low = vget_low_s8(filter_val_s8_1);
1845       high = vget_high_s8(filter_val_s8_1);
1846       int16x8_t filter_val_1_0 = vmovl_s8(low);
1847       int16x8_t filter_val_1_1 = vmovl_s8(high);
1848       filter_val_1_0 = vaddq_s16(filter_val_1_0, filter_offset_vec);
1849       filter_val_1_1 = vaddq_s16(filter_val_1_1, filter_offset_vec);
1850       low = vget_low_s8(filter_val_s8_2);
1851       high = vget_high_s8(filter_val_s8_2);
1852       int16x8_t filter_val_2_0 = vmovl_s8(low);
1853       int16x8_t filter_val_2_1 = vmovl_s8(high);
1854       filter_val_2_0 = vaddq_s16(filter_val_2_0, filter_offset_vec);
1855       filter_val_2_1 = vaddq_s16(filter_val_2_1, filter_offset_vec);
1856       low = vget_low_s8(filter_val_s8_3);
1857       high = vget_high_s8(filter_val_s8_3);
1858       int16x8_t filter_val_3_0 = vmovl_s8(low);
1859       int16x8_t filter_val_3_1 = vmovl_s8(high);
1860       filter_val_3_0 = vaddq_s16(filter_val_3_0, filter_offset_vec);
1861       filter_val_3_1 = vaddq_s16(filter_val_3_1, filter_offset_vec);
1862       acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_0),
1863                        vget_low_s16(input_val_0));
1864       acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_0),
1865                        vget_low_s16(input_val_0));
1866       acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_0),
1867                        vget_low_s16(input_val_0));
1868       acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_0),
1869                        vget_low_s16(input_val_0));
1870       acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_1),
1871                        vget_low_s16(input_val_1));
1872       acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_1),
1873                        vget_low_s16(input_val_1));
1874       acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_1),
1875                        vget_low_s16(input_val_1));
1876       acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_1),
1877                        vget_low_s16(input_val_1));
1878       acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_0),
1879                        vget_high_s16(input_val_0));
1880       acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_0),
1881                        vget_high_s16(input_val_0));
1882       acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_0),
1883                        vget_high_s16(input_val_0));
1884       acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_0),
1885                        vget_high_s16(input_val_0));
1886       acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_1),
1887                        vget_high_s16(input_val_1));
1888       acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_1),
1889                        vget_high_s16(input_val_1));
1890       acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_1),
1891                        vget_high_s16(input_val_1));
1892       acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_1),
1893                        vget_high_s16(input_val_1));
1894     }
1895     for (; in <= input_size - 8; in += 8) {
1896       const int8x8_t input_val_s8 = vld1_s8(input_data + in);
1897       const int8_t* filter_ptr = filter_data + in + out * input_size;
1898       int8x8_t filter_val_s8_0 = vld1_s8(filter_ptr);
1899       filter_ptr += input_size;
1900       int8x8_t filter_val_s8_1 = vld1_s8(filter_ptr);
1901       filter_ptr += input_size;
1902       int8x8_t filter_val_s8_2 = vld1_s8(filter_ptr);
1903       filter_ptr += input_size;
1904       int8x8_t filter_val_s8_3 = vld1_s8(filter_ptr);
1905       int16x8_t input_val = vmovl_s8(input_val_s8);
1906       input_val = vaddq_s16(input_val, input_offset_vec);
1907       int16x8_t filter_val_0 = vmovl_s8(filter_val_s8_0);
1908       filter_val_0 = vaddq_s16(filter_val_0, filter_offset_vec);
1909       int16x8_t filter_val_1 = vmovl_s8(filter_val_s8_1);
1910       filter_val_1 = vaddq_s16(filter_val_1, filter_offset_vec);
1911       int16x8_t filter_val_2 = vmovl_s8(filter_val_s8_2);
1912       filter_val_2 = vaddq_s16(filter_val_2, filter_offset_vec);
1913       int16x8_t filter_val_3 = vmovl_s8(filter_val_s8_3);
1914       filter_val_3 = vaddq_s16(filter_val_3, filter_offset_vec);
1915       acc0 =
1916           vmlal_s16(acc0, vget_low_s16(filter_val_0), vget_low_s16(input_val));
1917       acc1 =
1918           vmlal_s16(acc1, vget_low_s16(filter_val_1), vget_low_s16(input_val));
1919       acc2 =
1920           vmlal_s16(acc2, vget_low_s16(filter_val_2), vget_low_s16(input_val));
1921       acc3 =
1922           vmlal_s16(acc3, vget_low_s16(filter_val_3), vget_low_s16(input_val));
1923       acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
1924                        vget_high_s16(input_val));
1925       acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
1926                        vget_high_s16(input_val));
1927       acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
1928                        vget_high_s16(input_val));
1929       acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
1930                        vget_high_s16(input_val));
1931     }
1932     if (in < input_size) {
1933       int32 buf[16];
1934       vst1q_s32(buf + 0, acc0);
1935       vst1q_s32(buf + 4, acc1);
1936       vst1q_s32(buf + 8, acc2);
1937       vst1q_s32(buf + 12, acc3);
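      // Handle the remaining (fewer than 8) inputs in scalar code, spilling
      // the accumulators to memory first. Since the horizontal reduction
      // below sums all four lanes of each accumulator, correctness does not
      // depend on which lane a leftover product lands in; the modulo merely
      // spreads the leftovers across lanes.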
1938       for (; in < input_size; in++) {
1939         int lane = (in + 8 - input_size) % 4;
1940         const int32 input_val = input_data[in] + input_offset;
1941         for (int k = 0; k < kPeel; k++) {
1942           int32 filter_val =
1943               filter_data[in + (out + k) * input_size] + filter_offset;
1944           buf[lane + 4 * k] += filter_val * input_val;
1945         }
1946       }
1947       acc0 = vld1q_s32(buf + 0);
1948       acc1 = vld1q_s32(buf + 4);
1949       acc2 = vld1q_s32(buf + 8);
1950       acc3 = vld1q_s32(buf + 12);
1951     }
1952 
1953     // Horizontally reduce accumulators
1954     int32x2_t pairwise_reduced_acc_0 =
1955         vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0));
1956     int32x2_t pairwise_reduced_acc_1 =
1957         vpadd_s32(vget_low_s32(acc1), vget_high_s32(acc1));
1958     int32x2_t pairwise_reduced_acc_2 =
1959         vpadd_s32(vget_low_s32(acc2), vget_high_s32(acc2));
1960     int32x2_t pairwise_reduced_acc_3 =
1961         vpadd_s32(vget_low_s32(acc3), vget_high_s32(acc3));
1962     const int32x2_t reduced_lo =
1963         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
1964     const int32x2_t reduced_hi =
1965         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
1966     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
1967     // Add bias values.
1968     int32x4_t bias_vec = vld1q_s32(bias_data + out);
1969     reduced = vaddq_s32(reduced, bias_vec);
1970     if (shift_left) {
1971       const int32 multiplier_power_of_two = 1 << output_shift;
1972       reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
1973       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
1974     } else {
1975       // Multiply by the fixed-point multiplier.
1976       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
1977       // Rounding-shift-right.
1978       using gemmlowp::RoundingDivideByPOT;
1979       reduced = RoundingDivideByPOT(reduced, -output_shift);
1980     }
1981     // Add the output offset.
1982     const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
1983     reduced = vaddq_s32(reduced, output_offset_vec);
1984     // Narrow values down to 16 bit signed.
1985     const int16x4_t res16 = vqmovn_s32(reduced);
1986     // Narrow values down to 8 bit signed, saturating.
1987     int8x8_t res8 = vqmovn_s16(vcombine_s16(res16, res16));
1988     // Apply the clamping from the activation function
1989     res8 = vmax_s8(res8, vdup_n_s8(output_activation_min));
1990     res8 = vmin_s8(res8, vdup_n_s8(output_activation_max));
1991     // Store results to destination.
1992     vst1_lane_s8(output_data + out + 0, res8, 0);
1993     vst1_lane_s8(output_data + out + 1, res8, 1);
1994     vst1_lane_s8(output_data + out + 2, res8, 2);
1995     vst1_lane_s8(output_data + out + 3, res8, 3);
1996   }
1997 }
1998 
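// For reference, a minimal scalar sketch of the per-output requantization
// performed by the NEON code above, assuming the same positive-means-left
// output_shift convention. The helper name is hypothetical and not part of
// this header; it exists only to illustrate the arithmetic.
inline int32 ScalarRequantizeSketch(int32 acc, int32 output_multiplier,
                                    int output_shift, int32 output_offset,
                                    int32 output_activation_min,
                                    int32 output_activation_max) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  if (output_shift > 0) {
    // Left shift: scale by the power of two, then by the fixed-point
    // multiplier, mirroring the shift_left branch above.
    acc = SaturatingRoundingDoublingHighMul(acc * (1 << output_shift),
                                            output_multiplier);
  } else {
    // Fixed-point multiply followed by a rounding right shift.
    acc = SaturatingRoundingDoublingHighMul(acc, output_multiplier);
    acc = RoundingDivideByPOT(acc, -output_shift);
  }
  acc += output_offset;
  // Clamp to the activation range (the vmax/vmin pair above).
  if (acc < output_activation_min) acc = output_activation_min;
  if (acc > output_activation_max) acc = output_activation_max;
  return acc;
}
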
1999 struct LegacyInt8FullyConnectedAsGEMVWorkerTask : public gemmlowp::Task {
2000   LegacyInt8FullyConnectedAsGEMVWorkerTask(
2001       const RuntimeShape& input_shape, const int8_t* input_data,
2002       int32 input_offset, const RuntimeShape& filter_shape,
2003       const int8_t* filter_data, int32 filter_offset,
2004       const RuntimeShape& bias_shape, const int32* bias_data,
2005       int32 output_offset, int32 output_multiplier, int output_shift,
2006       int32 output_activation_min, int32 output_activation_max,
2007       const RuntimeShape& output_shape, int8_t* output_data, int row_start,
2008       int row_end)
2009       : input_shape_(input_shape),
2010         input_data_(input_data),
2011         input_offset_(input_offset),
2012         filter_shape_(filter_shape),
2013         filter_data_(filter_data),
2014         filter_offset_(filter_offset),
2015         bias_shape_(bias_shape),
2016         bias_data_(bias_data),
2017         output_offset_(output_offset),
2018         output_multiplier_(output_multiplier),
2019         output_shift_(output_shift),
2020         output_activation_min_(output_activation_min),
2021         output_activation_max_(output_activation_max),
2022         output_shape_(output_shape),
2023         output_data_(output_data),
2024         row_start_(row_start),
2025         row_end_(row_end) {}
2026 
2027   void Run() override {
2028     LegacyInt8FullyConnectedAsGEMVWorkerImpl(
2029         input_shape_, input_data_, input_offset_, filter_shape_, filter_data_,
2030         filter_offset_, bias_shape_, bias_data_, output_offset_,
2031         output_multiplier_, output_shift_, output_activation_min_,
2032         output_activation_max_, output_shape_, output_data_, row_start_,
2033         row_end_);
2034   }
2035 
2036   const RuntimeShape& input_shape_;
2037   const int8_t* input_data_;
2038   int32 input_offset_;
2039   const RuntimeShape& filter_shape_;
2040   const int8_t* filter_data_;
2041   int32 filter_offset_;
2042   const RuntimeShape& bias_shape_;
2043   const int32* bias_data_;
2044   int32 output_offset_;
2045   int32 output_multiplier_;
2046   int output_shift_;
2047   int32 output_activation_min_;
2048   int32 output_activation_max_;
2049   const RuntimeShape& output_shape_;
2050   int8_t* output_data_;
2051   int row_start_;
2052   int row_end_;
2053 };
2054 
2055 inline void LegacyInt8FullyConnectedAsGEMV(
2056     const RuntimeShape& input_shape, const int8_t* input_data,
2057     int32 input_offset, const RuntimeShape& filter_shape,
2058     const int8_t* filter_data, int32 filter_offset,
2059     const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
2060     int32 output_multiplier, int output_shift, int32 output_activation_min,
2061     int32 output_activation_max, const RuntimeShape& output_shape,
2062     int8_t* output_data, gemmlowp::GemmContext* gemmlowp_context) {
2063   const int output_dim_count = output_shape.DimensionsCount();
2064   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
2065   const int output_rows = output_shape.Dims(output_dim_count - 1);
2066   const int input_size = FlatSizeSkipDim(input_shape, 0);
2067   static constexpr int kKernelRows = 4;
2068   const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
2069       gemmlowp_context->max_num_threads(), output_rows, batches, input_size);
2070   if (thread_count == 1) {
2071     // Single-thread case: do the computation on the current thread; don't
2072     // use a threadpool.
2073     LegacyInt8FullyConnectedAsGEMVWorkerImpl(
2074         input_shape, input_data, input_offset, filter_shape, filter_data,
2075         filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
2076         output_shift, output_activation_min, output_activation_max,
2077         output_shape, output_data, 0, output_rows);
2078     return;
2079   }
2080 
2081   // Multi-threaded case: use the gemmlowp context's threadpool.
2082   TFLITE_DCHECK_GT(thread_count, 1);
2083   std::vector<LegacyInt8FullyConnectedAsGEMVWorkerTask> tasks;
2084   // TODO(b/131746020) don't create new heap allocations every time.
2085   // At least we make it a single heap allocation by using reserve().
2086   tasks.reserve(thread_count);
2087   const int kRowsPerWorker = gemmlowp::RoundUp<kKernelRows>(
2088       gemmlowp::CeilQuotient(output_rows, thread_count));
2089   int row_start = 0;
2090   for (int i = 0; i < thread_count; ++i) {
2091     int row_end = std::min(output_rows, row_start + kRowsPerWorker);
2092     tasks.emplace_back(input_shape, input_data, input_offset, filter_shape,
2093                        filter_data, filter_offset, bias_shape, bias_data,
2094                        output_offset, output_multiplier, output_shift,
2095                        output_activation_min, output_activation_max,
2096                        output_shape, output_data, row_start, row_end);
2097     row_start = row_end;
2098   }
2099   TFLITE_DCHECK_EQ(row_start, output_rows);
2100   gemmlowp_context->workers_pool()->Execute(tasks.size(), tasks.data());
2101 }
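
// A sketch of the row partitioning used above, with a hypothetical helper
// name; it only illustrates the arithmetic. For example, 10 output rows on
// 3 threads gives CeilQuotient(10, 3) = 4 and RoundUp<4>(4) = 4, so the
// workers get the row ranges [0, 4), [4, 8) and [8, 10).
inline void LegacyGEMVRowPartitionSketch(int output_rows, int thread_count) {
  static constexpr int kKernelRows = 4;
  const int rows_per_worker = gemmlowp::RoundUp<kKernelRows>(
      gemmlowp::CeilQuotient(output_rows, thread_count));
  int row_start = 0;
  for (int i = 0; i < thread_count; ++i) {
    const int row_end = std::min(output_rows, row_start + rows_per_worker);
    // Worker i would process rows [row_start, row_end).
    row_start = row_end;
  }
}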
2102 #endif  // USE_NEON
2103 
2104 inline void FullyConnected(
2105     const FullyConnectedParams& params, const RuntimeShape& input_shape,
2106     const int8* input_data, const RuntimeShape& filter_shape,
2107     const int8* filter_data, const RuntimeShape& bias_shape,
2108     const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
2109     gemmlowp::GemmContext* gemmlowp_context) {
2110   ruy::profiler::ScopeLabel label("FullyConnectedInt8/8bit");
2111 
2112 #ifdef USE_NEON
2113   const int32 input_offset = params.input_offset;
2114   const int32 filter_offset = params.weights_offset;
2115   const int32 output_offset = params.output_offset;
2116   const int32 output_multiplier = params.output_multiplier;
2117   const int output_shift = params.output_shift;
2118   const int32 output_activation_min = params.quantized_activation_min;
2119   const int32 output_activation_max = params.quantized_activation_max;
2120   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
2121   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
2122   // TODO(b/62193649): This really should be:
2123   //     const int batches = ArraySize(output_dims, 1);
2124   // but the current --variable_batch hack consists in overwriting the 3rd
2125   // dimension with the runtime batch size, as we don't keep track, for each
2126   // array, of which of its dimensions is the batch dimension.
2127   const int output_dim_count = output_shape.DimensionsCount();
2128   const int filter_dim_count = filter_shape.DimensionsCount();
2129   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
2130   if (batches == 1) {
2131     const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
2132                                         output_shape, output_dim_count - 1);
2133     if (output_size >= 4) {
2134       return LegacyInt8FullyConnectedAsGEMV(
2135           input_shape, input_data, input_offset, filter_shape, filter_data,
2136           filter_offset, bias_shape, bias_data, output_offset,
2137           output_multiplier, output_shift, output_activation_min,
2138           output_activation_max, output_shape, output_data, gemmlowp_context);
2139     }
2140   }
2141 #endif  // USE_NEON
2142 
2143 #ifdef GEMMLOWP_NEON
2144   const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
2145   const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
2146   TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
2147   const int output_rows = output_shape.Dims(output_dim_count - 1);
2148   TFLITE_DCHECK_EQ(output_rows, filter_rows);
2149   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
2150 
2151   gemmlowp::MatrixMap<const int8, gemmlowp::MapOrder::RowMajor> filter_matrix(
2152       filter_data, output_rows, filter_cols, filter_cols);
2153   gemmlowp::MatrixMap<const int8, gemmlowp::MapOrder::ColMajor> input_matrix(
2154       input_data, filter_cols, batches, filter_cols);
2155   gemmlowp::MatrixMap<int8, gemmlowp::MapOrder::ColMajor> output_matrix(
2156       output_data, output_rows, batches, output_rows);
2157   const auto& output_pipeline = GemmlowpOutputPipelineInt8::MakeExp(
2158       bias_data, output_rows, output_offset, output_multiplier, output_shift,
2159       output_activation_min, output_activation_max);
2160 
2161   gemmlowp::GemmWithOutputPipeline<
2162       int8, int8, gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams>(
2163       gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
2164       filter_offset, input_offset, output_pipeline);
2165   return;
2166 #endif  // GEMMLOWP_NEON
2167 
2168   // If both the GEMMLOWP_NEON and USE_NEON paths are skipped, fall back to
2169   // the reference implementation.
2170   reference_integer_ops::FullyConnected(params, input_shape, input_data,
2171                                         filter_shape, filter_data, bias_shape,
2172                                         bias_data, output_shape, output_data);
2173 }
2174 
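// A hedged usage sketch for the int8 FullyConnected above; shapes and
// quantization parameters are illustrative and the helper name is
// hypothetical. With batches == 1 and output_size >= 4, a NEON build takes
// the LegacyInt8FullyConnectedAsGEMV fast path.
inline void FullyConnectedInt8UsageSketch(
    const int8* input_data, const int8* filter_data, const int32* bias_data,
    int8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
  const RuntimeShape input_shape({1, 64});    // 1 batch of 64 activations.
  const RuntimeShape filter_shape({16, 64});  // 16 output rows, depth 64.
  const RuntimeShape bias_shape({16});
  const RuntimeShape output_shape({1, 16});
  FullyConnectedParams params;
  params.input_offset = 0;
  params.weights_offset = 0;
  params.output_offset = 0;
  params.output_multiplier = 1 << 30;  // 0.5 as a Q31 fixed-point value.
  params.output_shift = 0;
  params.quantized_activation_min = -128;
  params.quantized_activation_max = 127;
  FullyConnected(params, input_shape, input_data, filter_shape, filter_data,
                 bias_shape, bias_data, output_shape, output_data,
                 gemmlowp_context);
}
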
2175 struct LegacyShuffledFullyConnectedWorkerTask : gemmlowp::Task {
2176   LegacyShuffledFullyConnectedWorkerTask(const uint8* input_data,
2177                                          const int8* shuffled_weights_data,
2178                                          int batches, int output_depth,
2179                                          int output_stride, int accum_depth,
2180                                          const int32* bias_data,
2181                                          int32 output_multiplier,
2182                                          int output_shift, int16* output_data)
2183       : input_data_(input_data),
2184         shuffled_weights_data_(shuffled_weights_data),
2185         batches_(batches),
2186         output_depth_(output_depth),
2187         output_stride_(output_stride),
2188         accum_depth_(accum_depth),
2189         bias_data_(bias_data),
2190         output_multiplier_(output_multiplier),
2191         output_shift_(output_shift),
2192         output_data_(output_data) {}
2193 
2194   void Run() override {
2195     ShuffledFullyConnectedWorkerImpl(
2196         input_data_, shuffled_weights_data_, batches_, output_depth_,
2197         output_stride_, accum_depth_, bias_data_, output_multiplier_,
2198         output_shift_, output_data_);
2199   }
2200 
2201   const uint8* input_data_;
2202   const int8* shuffled_weights_data_;
2203   int batches_;
2204   int output_depth_;
2205   int output_stride_;
2206   int accum_depth_;
2207   const int32* bias_data_;
2208   int32 output_multiplier_;
2209   int output_shift_;
2210   int16* output_data_;
2211 };
2212 
2213 inline void ShuffledFullyConnected(
2214     const FullyConnectedParams& params, const RuntimeShape& input_shape,
2215     const uint8* input_data, const RuntimeShape& weights_shape,
2216     const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
2217     const int32* bias_data, const RuntimeShape& output_shape,
2218     int16* output_data, uint8* shuffled_input_workspace_data,
2219     gemmlowp::GemmContext* gemmlowp_context) {
2220   ruy::profiler::ScopeLabel label("ShuffledFullyConnected/8bit");
2221   const int32 output_multiplier = params.output_multiplier;
2222   const int output_shift = params.output_shift;
2223   const int32 output_activation_min = params.quantized_activation_min;
2224   const int32 output_activation_max = params.quantized_activation_max;
2225   (void)gemmlowp_context;  // Redundant here: the context is used below.
2226   TFLITE_DCHECK_EQ(output_activation_min, -32768);
2227   TFLITE_DCHECK_EQ(output_activation_max, 32767);
2228   TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
2229   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
2230   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
2231   // TODO(b/62193649): This really should be:
2232   //     const int batches = ArraySize(output_dims, 1);
2233   // but the current --variable_batch hack consists in overwriting the 3rd
2234   // dimension with the runtime batch size, as we don't keep track, for each
2235   // array, of which of its dimensions is the batch dimension.
2236   const int output_dim_count = output_shape.DimensionsCount();
2237   const int weights_dim_count = weights_shape.DimensionsCount();
2238   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
2239   const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
2240                                        output_shape, output_dim_count - 1);
2241   const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
2242   TFLITE_DCHECK((accum_depth % 16) == 0);
2243   TFLITE_DCHECK((output_depth % 4) == 0);
2244   // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
2245   // so that just reinterpreting them as int8 values is equivalent to
2246   // subtracting 128 from them, thus implementing for free the subtraction of
2247   // the zero_point value 128.
2248   const int8* int8_shuffled_weights_data =
2249       reinterpret_cast<const int8*>(shuffled_weights_data);
2250 
2251   // Shuffle and XOR the input activations into the workspace buffer.
2252   if (batches == 1) {
2253 #ifdef USE_NEON
2254     const uint8x16_t signbit = vdupq_n_u8(0x80);
2255     for (int i = 0; i < accum_depth; i += 16) {
2256       uint8x16_t val = vld1q_u8(input_data + i);
2257       val = veorq_u8(val, signbit);
2258       vst1q_u8(shuffled_input_workspace_data + i, val);
2259     }
2260 #else
2261     for (int i = 0; i < accum_depth; i++) {
2262       shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
2263     }
2264 #endif
2265   } else if (batches == 4) {
2266     uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
2267     int c = 0;
2268 #ifdef USE_NEON
2269     const uint8x16_t signbit = vdupq_n_u8(0x80);
2270     for (c = 0; c < accum_depth; c += 16) {
2271       const uint8* src_data_ptr = input_data + c;
2272       uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth);
2273       uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth);
2274       uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth);
2275       uint8x16_t val3 = vld1q_u8(src_data_ptr + 3 * accum_depth);
2276       val0 = veorq_u8(val0, signbit);
2277       val1 = veorq_u8(val1, signbit);
2278       val2 = veorq_u8(val2, signbit);
2279       val3 = veorq_u8(val3, signbit);
2280       vst1q_u8(shuffled_input_workspace_ptr + 0, val0);
2281       vst1q_u8(shuffled_input_workspace_ptr + 16, val1);
2282       vst1q_u8(shuffled_input_workspace_ptr + 32, val2);
2283       vst1q_u8(shuffled_input_workspace_ptr + 48, val3);
2284       shuffled_input_workspace_ptr += 64;
2285     }
2286 #else
2287     for (c = 0; c < accum_depth; c += 16) {
2288       for (int b = 0; b < 4; b++) {
2289         const uint8* src_data_ptr = input_data + b * accum_depth + c;
2290         for (int j = 0; j < 16; j++) {
2291           uint8 src_val = *src_data_ptr++;
2292           // Flip the sign bit, so that the kernel will only need to
2293           // reinterpret these uint8 values as int8, getting for free the
2294           // subtraction of the zero_point value 128.
2295           uint8 dst_val = src_val ^ 0x80;
2296           *shuffled_input_workspace_ptr++ = dst_val;
2297         }
2298       }
2299     }
2300 #endif
2301   } else {
2302     TFLITE_DCHECK(false);
2303     return;
2304   }
2305 
2306   static constexpr int kKernelRows = 4;
2307   const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
2308       gemmlowp_context->max_num_threads(), output_depth, batches, accum_depth);
2309   if (thread_count == 1) {
2310     // Single-thread case: do the computation on the current thread; don't
2311     // use a threadpool.
2312     ShuffledFullyConnectedWorkerImpl(
2313         shuffled_input_workspace_data, int8_shuffled_weights_data, batches,
2314         output_depth, output_depth, accum_depth, bias_data, output_multiplier,
2315         output_shift, output_data);
2316     return;
2317   }
2318 
2319   // Multi-threaded case: use the gemmlowp context's threadpool.
2320   TFLITE_DCHECK_GT(thread_count, 1);
2321   std::vector<gemmlowp::Task*> tasks(thread_count);
2322   const int kRowsPerWorker = gemmlowp::RoundUp<kKernelRows>(
2323       gemmlowp::CeilQuotient(output_depth, thread_count));
2324   int row_start = 0;
2325   for (int i = 0; i < thread_count; i++) {
2326     int row_end = std::min(output_depth, row_start + kRowsPerWorker);
2327     tasks[i] = new LegacyShuffledFullyConnectedWorkerTask(
2328         shuffled_input_workspace_data,
2329         int8_shuffled_weights_data + row_start * accum_depth, batches,
2330         row_end - row_start, output_depth, accum_depth, bias_data + row_start,
2331         output_multiplier, output_shift, output_data + row_start);
2332     row_start = row_end;
2333   }
2334   TFLITE_DCHECK_EQ(row_start, output_depth);
2335   gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
2336 }
2337 
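// A minimal sketch of the sign-bit trick used above (hypothetical helper,
// illustration only): XOR-ing a uint8 with 0x80 and reinterpreting the
// result as int8 is the same as subtracting the zero point 128.
inline int8 FlipSignBitSketch(uint8 v) {
  // E.g. v = 200: 200 ^ 0x80 = 72 = 200 - 128.
  //      v = 100: 100 ^ 0x80 = 228, which reinterprets as -28 = 100 - 128.
  return static_cast<int8>(v ^ 0x80);
}
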
2338 inline void ShuffledFullyConnected(
2339     const uint8* input_data, const Dims<4>& input_dims,
2340     const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
2341     const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
2342     int output_shift, int32 output_activation_min, int32 output_activation_max,
2343     int16* output_data, const Dims<4>& output_dims,
2344     uint8* shuffled_input_workspace_data,
2345     gemmlowp::GemmContext* gemmlowp_context) {
2346   tflite::FullyConnectedParams op_params;
2347   op_params.output_multiplier = output_multiplier;
2348   // Legacy ops used mixed left and right shifts. Now all are positive-means-left.
2349   op_params.output_shift = kReverseShift * output_shift;
2350   op_params.quantized_activation_min = output_activation_min;
2351   op_params.quantized_activation_max = output_activation_max;
2352 
2353   ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
2354                          DimsToShape(weights_dims), shuffled_weights_data,
2355                          DimsToShape(bias_dims), bias_data,
2356                          DimsToShape(output_dims), output_data,
2357                          shuffled_input_workspace_data, gemmlowp_context);
2358 }
2359 
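// A sketch of the shift-convention mapping applied above: legacy callers
// pass output_shift as a right-shift amount, while the params structs
// expect positive-means-left, so kReverseShift (== -1) negates it. The
// helper name is hypothetical.
inline int LegacyToNewOutputShiftSketch(int legacy_right_shift) {
  // E.g. a legacy right shift of 3 becomes -3 (still a right shift of 3
  // under the new convention).
  return kReverseShift * legacy_right_shift;
}
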
2360 template <typename T>
2361 inline void ExtractPatchIntoBufferColumn(
2362     const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
2363     int stride_width, int stride_height, int pad_width, int pad_height,
2364     int in_width, int in_height, int in_depth, int single_buffer_length,
2365     int buffer_id, const T* in_data, T* conv_buffer_data, uint8 zero_byte) {
2366   ExtractPatchIntoBufferColumn(
2367       DimsToShape(input_dims), w, h, b, kheight, kwidth, stride_width,
2368       stride_height, pad_width, pad_height, in_width, in_height, in_depth,
2369       single_buffer_length, buffer_id, in_data, conv_buffer_data, zero_byte);
2370 }
2371 
2372 template <typename T>
2373 void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
2374                    const Dims<4>& filter_dims, int stride_width,
2375                    int stride_height, int dilation_width_factor,
2376                    int dilation_height_factor, int pad_width, int pad_height,
2377                    const Dims<4>& output_dims, uint8 zero_byte,
2378                    T* im2col_data) {
2379   tflite::ConvParams op_params;
2380   // Padding type is ignored, but still set.
2381   op_params.padding_type = PaddingType::kSame;
2382   op_params.padding_values.width = pad_width;
2383   op_params.padding_values.height = pad_height;
2384   op_params.stride_width = stride_width;
2385   op_params.stride_height = stride_height;
2386   op_params.dilation_width_factor = dilation_width_factor;
2387   op_params.dilation_height_factor = dilation_height_factor;
2388 
2389   DilatedIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
2390                 DimsToShape(filter_dims), DimsToShape(output_dims),
2391                 im2col_data);
2392 }
2393 
2394 template <typename T>
2395 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
2396             int stride_height, int pad_width, int pad_height, int kheight,
2397             int kwidth, uint8 zero_byte, T* output_data,
2398             const Dims<4>& output_dims) {
2399   tflite::ConvParams op_params;
2400   // Padding type is ignored, but still set.
2401   op_params.padding_type = PaddingType::kSame;
2402   op_params.padding_values.width = pad_width;
2403   op_params.padding_values.height = pad_height;
2404   op_params.stride_width = stride_width;
2405   op_params.stride_height = stride_height;
2406   op_params.dilation_width_factor = 1;
2407   op_params.dilation_height_factor = 1;
2408 
2409   Im2col(op_params, kheight, kwidth, zero_byte, DimsToShape(input_dims),
2410          input_data, DimsToShape(output_dims), output_data);
2411 }
2412 
2413 // legacy, for compatibility with old checked-in code
2414 template <typename T>
2415 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
2416             int pad_width, int pad_height, int kheight, int kwidth,
2417             uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
2418   Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
2419          kwidth, zero_byte, output_data, output_dims);
2420 }
2421 
2422 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
2423                  const float* input_data, const RuntimeShape& filter_shape,
2424                  const float* filter_data, const RuntimeShape& bias_shape,
2425                  const float* bias_data, const RuntimeShape& output_shape,
2426                  float* output_data, const RuntimeShape& im2col_shape,
2427                  float* im2col_data) {
2428   const int stride_width = params.stride_width;
2429   const int stride_height = params.stride_height;
2430   const int dilation_width_factor = params.dilation_width_factor;
2431   const int dilation_height_factor = params.dilation_height_factor;
2432   const float output_activation_min = params.float_activation_min;
2433   const float output_activation_max = params.float_activation_max;
2434   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2435   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2436   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2437 
2438   (void)im2col_data;
2439   (void)im2col_shape;
2440   ruy::profiler::ScopeLabel label("Conv");
2441 
2442   // NB: the float 0.0f value is represented by all zero bytes.
2443   const uint8 float_zero_byte = 0x00;
2444   const float* gemm_input_data = nullptr;
2445   const RuntimeShape* gemm_input_shape = nullptr;
2446   const int filter_width = filter_shape.Dims(2);
2447   const int filter_height = filter_shape.Dims(1);
2448   const bool need_dilated_im2col =
2449       dilation_width_factor != 1 || dilation_height_factor != 1;
2450   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
2451                            filter_width != 1 || filter_height != 1;
2452   if (need_dilated_im2col) {
2453     DilatedIm2col(params, float_zero_byte, input_shape, input_data,
2454                   filter_shape, output_shape, im2col_data);
2455     gemm_input_data = im2col_data;
2456     gemm_input_shape = &im2col_shape;
2457   } else if (need_im2col) {
2458     TFLITE_DCHECK(im2col_data);
2459     Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
2460            input_data, im2col_shape, im2col_data);
2461     gemm_input_data = im2col_data;
2462     gemm_input_shape = &im2col_shape;
2463   } else {
2464     // TODO(aselle): We need to make sure to not send im2col if it is not
2465     // needed.
2466     TFLITE_DCHECK(!im2col_data);
2467     gemm_input_data = input_data;
2468     gemm_input_shape = &input_shape;
2469   }
2470 
2471   // The following code computes matrix multiplication c = a * transpose(b)
2472   // with CBLAS, where:
2473   // * `a` is a matrix with dimensions (m, k).
2474   // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n).
2475   // * `c` is a matrix with dimensions (m, n).
2476   // The naming of variables is aligned with the CBLAS specification here.
2477   const float* a = gemm_input_data;
2478   const float* b = filter_data;
2479   float* c = output_data;
2480   const int gemm_input_dims = gemm_input_shape->DimensionsCount();
2481   int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
2482   int n = output_shape.Dims(3);
2483   int k = gemm_input_shape->Dims(gemm_input_dims - 1);
2484 
2485 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
2486   // The strides of matrices a, b, and c, respectively.
2487   int stride_a = k;
2488   int stride_b = k;
2489   int stride_c = n;
2490 
2491   cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
2492               stride_a, b, stride_b, 0.0f, c, stride_c);
2493 #else
2494   // When an optimized CBLAS implementation is not available, fall back
2495   // to using Eigen.
2496   typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
2497       Matrix;
2498   typedef Eigen::Map<Matrix> MatrixRef;
2499   typedef Eigen::Map<const Matrix> ConstMatrixRef;
2500 
2501   MatrixRef matrix_c(c, m, n);
2502   ConstMatrixRef matrix_a(a, m, k);
2503   ConstMatrixRef matrix_b(b, n, k);
2504 
2505   // The following special casing for when a or b is a vector is required,
2506   // as Eigen seems to fail to make this optimization on its own.
2507   if (n == 1) {
2508     ruy::profiler::ScopeLabel label("GEMV");
2509     matrix_c.col(0).noalias() = matrix_a * matrix_b.row(0).transpose();
2510   } else if (m == 1) {
2511     ruy::profiler::ScopeLabel label("GEMV");
2512     matrix_c.row(0).noalias() = matrix_a.row(0) * matrix_b.transpose();
2513   } else {
2514     ruy::profiler::ScopeLabel label("GEMM");
2515     matrix_c.noalias() = matrix_a * matrix_b.transpose();
2516   }
2517 
2518 #endif  //  defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
2519 
2520   optimized_ops::AddBiasAndEvalActivationFunction(
2521       output_activation_min, output_activation_max, bias_shape, bias_data,
2522       output_shape, output_data);
2523 }
2524 
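// A worked sketch of the GEMM dimensions in the float Conv above, with
// illustrative NHWC shapes (all values hypothetical): input 1x8x8x3,
// filter 16x3x3x3 (OHWI), stride 1, 'same' padding, output 1x8x8x16.
// After im2col, a is (m, k) = (64, 27), b is (n, k) = (16, 27), and
// c = a * transpose(b) is (m, n) = (64, 16).
inline void ConvAsGemmDimsSketch() {
  const int batches = 1, output_height = 8, output_width = 8;
  const int output_depth = 16, filter_h = 3, filter_w = 3, input_depth = 3;
  const int m = batches * output_height * output_width;  // 64 patches.
  const int n = output_depth;                            // 16 filters.
  const int k = filter_h * filter_w * input_depth;       // 27 per patch.
  TFLITE_DCHECK_EQ(m, 64);
  TFLITE_DCHECK_EQ(n, 16);
  TFLITE_DCHECK_EQ(k, 27);
}
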
2525 inline void Conv(const float* input_data, const Dims<4>& input_dims,
2526                  const float* filter_data, const Dims<4>& filter_dims,
2527                  const float* bias_data, const Dims<4>& bias_dims,
2528                  int stride_width, int stride_height, int dilation_width_factor,
2529                  int dilation_height_factor, int pad_width, int pad_height,
2530                  float output_activation_min, float output_activation_max,
2531                  float* output_data, const Dims<4>& output_dims,
2532                  float* im2col_data, const Dims<4>& im2col_dims) {
2533   tflite::ConvParams op_params;
2534   // Padding type is ignored, but still set.
2535   op_params.padding_type = PaddingType::kSame;
2536   op_params.padding_values.width = pad_width;
2537   op_params.padding_values.height = pad_height;
2538   op_params.stride_width = stride_width;
2539   op_params.stride_height = stride_height;
2540   op_params.dilation_width_factor = dilation_width_factor;
2541   op_params.dilation_height_factor = dilation_height_factor;
2542   op_params.float_activation_min = output_activation_min;
2543   op_params.float_activation_max = output_activation_max;
2544 
2545   Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
2546        filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
2547        output_data, DimsToShape(im2col_dims), im2col_data);
2548 }
2549 
2550 inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
2551                        const int8_t* filter_data, const Dims<4>& filter_dims,
2552                        const float* bias_data, const Dims<4>& bias_dims,
2553                        int stride_width, int stride_height, int pad_width,
2554                        int pad_height, float* scaling_factors_ptr,
2555                        float output_activation_min, float output_activation_max,
2556                        int32_t* scratch_data, const Dims<4>& scratch_dims,
2557                        float* output_data, const Dims<4>& output_dims,
2558                        int8_t* im2col_data, const Dims<4>& im2col_dims,
2559                        CpuBackendContext* context) {
2560   tflite::ConvParams op_params;
2561   // Padding type is ignored, but still set.
2562   op_params.padding_type = PaddingType::kSame;
2563   op_params.padding_values.width = pad_width;
2564   op_params.padding_values.height = pad_height;
2565   op_params.stride_width = stride_width;
2566   op_params.stride_height = stride_height;
2567   op_params.float_activation_min = output_activation_min;
2568   op_params.float_activation_max = output_activation_max;
2569 
2570   HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims),
2571              input_data, DimsToShape(filter_dims), filter_data,
2572              DimsToShape(bias_dims), bias_data, DimsToShape(scratch_dims),
2573              scratch_data, DimsToShape(output_dims), output_data,
2574              DimsToShape(im2col_dims), im2col_data, context);
2575 }
2576 
2577 template <FusedActivationFunctionType Ac>
2578 void Conv(const float* input_data, const Dims<4>& input_dims,
2579           const float* filter_data, const Dims<4>& filter_dims,
2580           const float* bias_data, const Dims<4>& bias_dims, int stride_width,
2581           int stride_height, int dilation_width_factor,
2582           int dilation_height_factor, int pad_width, int pad_height,
2583           float* output_data, const Dims<4>& output_dims, float* im2col_data,
2584           const Dims<4>& im2col_dims) {
2585   float output_activation_min, output_activation_max;
2586   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
2587   Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
2588        stride_width, stride_height, dilation_width_factor,
2589        dilation_height_factor, pad_width, pad_height, output_activation_min,
2590        output_activation_max, output_data, output_dims, im2col_data,
2591        im2col_dims);
2592 }
2593 
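// A small sketch of the mapping done by GetActivationMinMax above
// (illustrative only): kNone yields an effectively unbounded range, kRelu
// clamps at [0, +inf), kRelu6 at [0, 6], and kRelu1 at [-1, 1].
inline void ActivationRangeSketch() {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(FusedActivationFunctionType::kRelu6,
                      &output_activation_min, &output_activation_max);
  TFLITE_DCHECK_EQ(output_activation_min, 0.0f);
  TFLITE_DCHECK_EQ(output_activation_max, 6.0f);
}
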
2594 // legacy, for compatibility with old checked-in code
2595 template <FusedActivationFunctionType Ac>
2596 void Conv(const float* input_data, const Dims<4>& input_dims,
2597           const float* filter_data, const Dims<4>& filter_dims,
2598           const float* bias_data, const Dims<4>& bias_dims, int stride_width,
2599           int stride_height, int pad_width, int pad_height, float* output_data,
2600           const Dims<4>& output_dims, float* im2col_data,
2601           const Dims<4>& im2col_dims) {
2602   float output_activation_min, output_activation_max;
2603   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
2604   Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
2605        stride_width, stride_height, 1, 1, pad_width, pad_height,
2606        output_activation_min, output_activation_max, output_data, output_dims,
2607        im2col_data, im2col_dims);
2608 }
2609 
2610 // legacy, for compatibility with old checked-in code
2611 template <FusedActivationFunctionType Ac>
2612 void Conv(const float* input_data, const Dims<4>& input_dims,
2613           const float* filter_data, const Dims<4>& filter_dims,
2614           const float* bias_data, const Dims<4>& bias_dims, int stride,
2615           int pad_width, int pad_height, float* output_data,
2616           const Dims<4>& output_dims, float* im2col_data,
2617           const Dims<4>& im2col_dims) {
2618   Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
2619            bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
2620            output_dims, im2col_data, im2col_dims);
2621 }
2622 
2623 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
2624                  const uint8* input_data, const RuntimeShape& filter_shape,
2625                  const uint8* filter_data, const RuntimeShape& bias_shape,
2626                  const int32* bias_data, const RuntimeShape& output_shape,
2627                  uint8* output_data, const RuntimeShape& im2col_shape,
2628                  uint8* im2col_data, gemmlowp::GemmContext* gemmlowp_context) {
2629   ruy::profiler::ScopeLabel label("Conv/8bit");
2630   const int stride_width = params.stride_width;
2631   const int stride_height = params.stride_height;
2632   const int dilation_width_factor = params.dilation_width_factor;
2633   const int dilation_height_factor = params.dilation_height_factor;
2634   const int32 input_offset = params.input_offset;
2635   const int32 filter_offset = params.weights_offset;
2636   const int32 output_offset = params.output_offset;
2637   const int32 output_multiplier = params.output_multiplier;
2638   const int output_shift = params.output_shift;
2639   const int32 output_activation_min = params.quantized_activation_min;
2640   const int32 output_activation_max = params.quantized_activation_max;
2641   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2642   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2643   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2644 
2645   const uint8* gemm_input_data = nullptr;
2646   const RuntimeShape* gemm_input_shape = nullptr;
2647   const int filter_width = filter_shape.Dims(2);
2648   const int filter_height = filter_shape.Dims(1);
2649   const bool need_dilated_im2col =
2650       dilation_width_factor != 1 || dilation_height_factor != 1;
2651   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
2652                            filter_width != 1 || filter_height != 1;
2653   if (need_dilated_im2col) {
2654     TFLITE_DCHECK(im2col_data);
2655     const int input_zero_point = -input_offset;
2656     TFLITE_DCHECK_GE(input_zero_point, 0);
2657     TFLITE_DCHECK_LE(input_zero_point, 255);
2658     DilatedIm2col(params, input_zero_point, input_shape, input_data,
2659                   filter_shape, output_shape, im2col_data);
2660     gemm_input_data = im2col_data;
2661     gemm_input_shape = &im2col_shape;
2662   } else if (need_im2col) {
2663     TFLITE_DCHECK(im2col_data);
2664     const int input_zero_point = -input_offset;
2665     TFLITE_DCHECK_GE(input_zero_point, 0);
2666     TFLITE_DCHECK_LE(input_zero_point, 255);
2667     Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
2668            input_data, im2col_shape, im2col_data);
2669     gemm_input_data = im2col_data;
2670     gemm_input_shape = &im2col_shape;
2671   } else {
2672     TFLITE_DCHECK(!im2col_data);
2673     gemm_input_data = input_data;
2674     gemm_input_shape = &input_shape;
2675   }
2676 
2677   const int gemm_input_rows = gemm_input_shape->Dims(3);
2678   // Using FlatSizeSkipDim causes a segfault in some contexts (see b/79927784);
2679   // the root cause has not yet been identified. The same applies below to the
2680   // other commented-out calls. This is a partial rollback of cl/196819423.
2681   // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
2682   const int gemm_input_cols = gemm_input_shape->Dims(0) *
2683                               gemm_input_shape->Dims(1) *
2684                               gemm_input_shape->Dims(2);
2685   const int filter_rows = filter_shape.Dims(0);
2686   // See b/79927784.
2687   // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
2688   const int filter_cols =
2689       filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
2690   const int output_rows = output_shape.Dims(3);
2691   // See b/79927784.
2692   // const int output_cols = FlatSizeSkipDim(output_shape, 3);
2693   const int output_cols =
2694       output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
2695   TFLITE_DCHECK_EQ(output_rows, filter_rows);
2696   TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
2697   TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
2698   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
2699 
2700 #ifdef USE_NEON
2701   if (gemm_input_cols == 1 && output_rows >= 4) {
2702     RuntimeShape fc_filter_shape{
2703         filter_shape.Dims(0),
2704         filter_shape.Dims(filter_shape.DimensionsCount() - 1)};
2705 
2706     return FullyConnectedAsGEMV(
2707         *gemm_input_shape, gemm_input_data, input_offset, fc_filter_shape,
2708         filter_data, filter_offset, bias_shape, bias_data, output_offset,
2709         output_multiplier, output_shift, output_activation_min,
2710         output_activation_max, output_shape, output_data, gemmlowp_context);
2711   }
2712 #endif
2713 
2714   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
2715       filter_data, filter_rows, filter_cols);
2716   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
2717       gemm_input_data, gemm_input_rows, gemm_input_cols);
2718   gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
2719       output_data, output_rows, output_cols);
2720   const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
2721       bias_data, output_rows, output_offset, output_multiplier, output_shift,
2722       output_activation_min, output_activation_max);
2723   gemmlowp::GemmWithOutputPipeline<uint8, uint8,
2724                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
2725       gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
2726       filter_offset, input_offset, output_pipeline);
2727 }
2728 
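// In the quantized Conv above, the byte used to pad im2col patches is the
// input zero point, recovered as -input_offset (TFLite stores the offset
// as the negated zero point). A sketch with a hypothetical helper name:
inline uint8 QuantizedIm2colPaddingByteSketch(int32 input_offset) {
  const int input_zero_point = -input_offset;  // e.g. -(-128) == 128.
  TFLITE_DCHECK_GE(input_zero_point, 0);
  TFLITE_DCHECK_LE(input_zero_point, 255);
  return static_cast<uint8>(input_zero_point);
}
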
2729 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
2730                  int32 input_offset, const uint8* filter_data,
2731                  const Dims<4>& filter_dims, int32 filter_offset,
2732                  const int32* bias_data, const Dims<4>& bias_dims,
2733                  int stride_width, int stride_height, int dilation_width_factor,
2734                  int dilation_height_factor, int pad_width, int pad_height,
2735                  int32 output_offset, int32 output_multiplier, int output_shift,
2736                  int32 output_activation_min, int32 output_activation_max,
2737                  uint8* output_data, const Dims<4>& output_dims,
2738                  uint8* im2col_data, const Dims<4>& im2col_dims,
2739                  gemmlowp::GemmContext* gemmlowp_context) {
2740   tflite::ConvParams op_params;
2741   // Padding type is ignored, but still set.
2742   op_params.padding_type = PaddingType::kSame;
2743   op_params.padding_values.width = pad_width;
2744   op_params.padding_values.height = pad_height;
2745   op_params.stride_width = stride_width;
2746   op_params.stride_height = stride_height;
2747   op_params.dilation_width_factor = dilation_width_factor;
2748   op_params.dilation_height_factor = dilation_height_factor;
2749   op_params.input_offset = input_offset;
2750   op_params.weights_offset = filter_offset;
2751   op_params.output_offset = output_offset;
2752   op_params.output_multiplier = output_multiplier;
2753   // Legacy ops used mixed left and right shifts. Now all are positive-means-left.
2754   op_params.output_shift = kReverseShift * output_shift;
2755   op_params.quantized_activation_min = output_activation_min;
2756   op_params.quantized_activation_max = output_activation_max;
2757 
2758   Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
2759        filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
2760        output_data, DimsToShape(im2col_dims), im2col_data, gemmlowp_context);
2761 }
2762 
2763 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
2764                  int32 input_offset, const uint8* filter_data,
2765                  const Dims<4>& filter_dims, int32 filter_offset,
2766                  const int32* bias_data, const Dims<4>& bias_dims,
2767                  int stride_width, int stride_height, int pad_width,
2768                  int pad_height, int32 output_offset, int32 output_multiplier,
2769                  int output_shift, int32 output_activation_min,
2770                  int32 output_activation_max, uint8* output_data,
2771                  const Dims<4>& output_dims, uint8* im2col_data,
2772                  const Dims<4>& im2col_dims,
2773                  gemmlowp::GemmContext* gemmlowp_context) {
2774   Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
2775        filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
2776        pad_width, pad_height, output_offset, output_multiplier, output_shift,
2777        output_activation_min, output_activation_max, output_data, output_dims,
2778        im2col_data, im2col_dims, gemmlowp_context);
2779 }
2780 
2781 // legacy, for compatibility with old checked-in code
2782 template <FusedActivationFunctionType Ac>
2783 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
2784                  int32 input_offset, const uint8* filter_data,
2785                  const Dims<4>& filter_dims, int32 filter_offset,
2786                  const int32* bias_data, const Dims<4>& bias_dims,
2787                  int stride_width, int stride_height, int pad_width,
2788                  int pad_height, int32 output_offset, int32 output_multiplier,
2789                  int output_shift, int32 output_activation_min,
2790                  int32 output_activation_max, uint8* output_data,
2791                  const Dims<4>& output_dims, uint8* im2col_data,
2792                  const Dims<4>& im2col_dims,
2793                  gemmlowp::GemmContext* gemmlowp_context) {
2794   static_assert(Ac == FusedActivationFunctionType::kNone ||
2795                     Ac == FusedActivationFunctionType::kRelu ||
2796                     Ac == FusedActivationFunctionType::kRelu6 ||
2797                     Ac == FusedActivationFunctionType::kRelu1,
2798                 "");
2799   if (Ac == FusedActivationFunctionType::kNone) {
2800     TFLITE_DCHECK_EQ(output_activation_min, 0);
2801     TFLITE_DCHECK_EQ(output_activation_max, 255);
2802   }
2803   Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
2804        filter_offset, bias_data, bias_dims, stride_width, stride_height,
2805        pad_width, pad_height, output_offset, output_multiplier, output_shift,
2806        output_activation_min, output_activation_max, output_data, output_dims,
2807        im2col_data, im2col_dims, gemmlowp_context);
2808 }
2809 
2810 // legacy, for compatibility with old checked-in code
2811 template <FusedActivationFunctionType Ac>
2812 void Conv(const uint8* input_data, const Dims<4>& input_dims,
2813           int32 input_offset, const uint8* filter_data,
2814           const Dims<4>& filter_dims, int32 filter_offset,
2815           const int32* bias_data, const Dims<4>& bias_dims, int stride,
2816           int pad_width, int pad_height, int32 output_offset,
2817           int32 output_multiplier, int output_shift,
2818           int32 output_activation_min, int32 output_activation_max,
2819           uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
2820           const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) {
2821   static_assert(Ac == FusedActivationFunctionType::kNone ||
2822                     Ac == FusedActivationFunctionType::kRelu ||
2823                     Ac == FusedActivationFunctionType::kRelu6 ||
2824                     Ac == FusedActivationFunctionType::kRelu1,
2825                 "");
2826   Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
2827        filter_offset, bias_data, bias_dims, stride, stride, pad_width,
2828        pad_height, output_offset, output_multiplier, output_shift,
2829        output_activation_min, output_activation_max, output_data, output_dims,
2830        im2col_data, im2col_dims, gemmlowp_context);
2831 }
2832 
2833 // legacy, for compatibility with old checked-in code
2834 template <FusedActivationFunctionType Ac, typename T>
2835 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
2836             int pad_width, int pad_height, int kheight, int kwidth,
2837             uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
2838   Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
2839          kwidth, zero_byte, output_data, output_dims);
2840 }
2841 
2842 // legacy, for compatibility with old checked-in code
2843 template <FusedActivationFunctionType Ac>
2844 void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
2845                 const float* filter_data, const Dims<4>& filter_dims,
2846                 const float* bias_data, const Dims<4>& bias_dims,
2847                 float* output_data, const Dims<4>& output_dims) {
2848   ruy::profiler::ScopeLabel label("ConvAsGemm");
2849 
2850   const auto input_matrix_map =
2851       MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
2852   const auto filter_matrix_map =
2853       MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
2854   auto output_matrix_map =
2855       MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
2856 
2857   Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
2858 
2859   AddBiasAndEvalActivationFunction<Ac>(bias_data, bias_dims, output_data,
2860                                        output_dims);
2861 }
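
// Illustrative sketch: ConvAsGemm is intended for cases where the convolution
// is already a pure matrix multiply (such as a 1x1 filter). With out_c
// indexing output channels, k the flattened input depth and b the flattened
// batch/spatial index, the GEMM-plus-bias step above computes
//
//   output[out_c, b] = bias[out_c] + sum_k filter[out_c, k] * input[k, b];
//
// followed by the fused activation Ac.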

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
                int32 input_offset, const uint8* filter_data,
                const Dims<4>& filter_dims, int32 filter_offset,
                const int32* bias_data, const Dims<4>& bias_dims,
                int32 output_offset, int32 output_multiplier, int output_shift,
                int32 output_activation_min, int32 output_activation_max,
                uint8* output_data, const Dims<4>& output_dims,
                gemmlowp::GemmContext* gemmlowp_context) {
  ruy::profiler::ScopeLabel label("ConvAsGemm/8bit");
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  const int input_rows = input_dims.sizes[0];
  const int input_cols = FlatSizeSkipDim(input_dims, 0);
  const int filter_rows = filter_dims.sizes[3];
  const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
  const int output_rows = output_dims.sizes[0];
  const int output_cols = FlatSizeSkipDim(output_dims, 0);
  TFLITE_DCHECK_EQ(output_rows, filter_rows);
  TFLITE_DCHECK_EQ(output_cols, input_cols);
  TFLITE_DCHECK_EQ(filter_cols, input_rows);
  TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
  TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
  TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
  TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
  gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
      filter_data, output_rows, filter_cols, filter_cols);
  gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
      input_data, filter_cols, output_cols, filter_cols);
  gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
      output_data, output_rows, output_cols, output_rows);
  const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
      bias_data, output_rows, output_offset, output_multiplier, -output_shift,
      output_activation_min, output_activation_max);
  gemmlowp::GemmWithOutputPipeline<uint8, uint8,
                                   gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
      gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
      filter_offset, input_offset, output_pipeline);
}
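
// Simplified scalar model of what the gemmlowp output pipeline built by
// MakeExp does per accumulator (illustrative only; the real pipeline is the
// fused sequence of output stages passed to GemmWithOutputPipeline):
//
//   int32 acc = gemm_accumulator + bias[out_c];
//   // Scale by output_multiplier with exponent -output_shift, rounding to
//   // nearest; "RoundingScale" is a hypothetical name for that fixed-point
//   // multiply.
//   acc = RoundingScale(acc, output_multiplier, -output_shift);
//   acc += output_offset;
//   acc = std::min(std::max(acc, output_activation_min),
//                  output_activation_max);
//   uint8 result = static_cast<uint8>(acc);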

inline void TransposeConv(
    const ConvParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& filter_shape,
    const float* filter_data, const RuntimeShape& output_shape,
    float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
  ruy::profiler::ScopeLabel label("TransposeConv");
  // Note we could use transposed weights with forward conv for unstrided
  // cases. But we are already getting good performance with this code as-is.
  TFLITE_DCHECK(im2col_data);
  TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
                  output_shape, im2col_data);

  const auto im2col_matrix_map =
      MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
  const auto filter_matrix_map =
      MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
  auto output_matrix_map =
      MapAsMatrixWithLastDimAsRows(output_data, output_shape);

  Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
}
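
// Illustrative sketch: conceptually, transpose convolution scatters each
// input pixel into the output (the reverse of the forward-conv gather).
// Assuming an HWOI filter layout for readability and ignoring padding:
//
//   for each (in_y, in_x, in_c):
//     for each (k_y, k_x, out_c):
//       output[in_y * stride_height + k_y, in_x * stride_width + k_x, out_c]
//           += input[in_y, in_x, in_c] * filter[k_y, k_x, out_c, in_c];
//
// The implementation above realizes the same scatter via TransposeIm2col
// followed by the shared GEMM path.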

inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
                          const float* filter_data, const Dims<4>& filter_dims,
                          int stride_width, int stride_height, int pad_width,
                          int pad_height, float* output_data,
                          const Dims<4>& output_dims, float* im2col_data,
                          const Dims<4>& im2col_dims) {
  tflite::ConvParams op_params;
  // Padding type is ignored, but still set.
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = pad_width;
  op_params.padding_values.height = pad_height;
  op_params.stride_width = stride_width;
  op_params.stride_height = stride_height;

  TransposeConv(op_params, DimsToShape(input_dims), input_data,
                DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
                output_data, DimsToShape(im2col_dims), im2col_data);
}

inline void TransposeConvV2(
    const ConvParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
    const float* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
    float* output_data, const RuntimeShape& col2im_shape, float* col2im_data,
    CpuBackendContext* cpu_backend_context) {
  TransposeConvV2(params, input_shape, input_data, hwoi_ordered_filter_shape,
                  hwoi_ordered_filter_data, /*bias_shape*/ RuntimeShape(),
                  /*bias_data*/ nullptr, output_shape, output_data,
                  col2im_shape, col2im_data, cpu_backend_context);
}

template <typename T>
void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
                     const Dims<4>& filter_dims, int stride_width,
                     int stride_height, int pad_width, int pad_height,
                     const Dims<4>& output_dims, uint8 zero_byte,
                     T* im2col_data) {
  tflite::ConvParams op_params;
  // Padding type is ignored, but still set.
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = pad_width;
  op_params.padding_values.height = pad_height;
  op_params.stride_width = stride_width;
  op_params.stride_height = stride_height;

  TransposeIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
                  DimsToShape(filter_dims), DimsToShape(output_dims),
                  im2col_data);
}

inline void LstmCell(
    const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
    const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
    const float* prev_activ_data, const RuntimeShape& weights_shape,
    const float* weights_data, const RuntimeShape& unextended_bias_shape,
    const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
    const float* prev_state_data,
    const RuntimeShape& unextended_output_state_shape, float* output_state_data,
    const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
    const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
    const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
  ruy::profiler::ScopeLabel label("LstmCell");
  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
  const RuntimeShape input_shape =
      RuntimeShape::ExtendedShape(4, unextended_input_shape);
  const RuntimeShape prev_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
  const RuntimeShape bias_shape =
      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
  const RuntimeShape prev_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
  const RuntimeShape output_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
  const RuntimeShape output_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
  const RuntimeShape concat_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
  const RuntimeShape activ_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);

  const int weights_dim_count = weights_shape.DimensionsCount();
  MatchingDim(  // batches
      input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
      output_state_shape, 0, output_activ_shape, 0);
  MatchingDim(  // height
      input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
      output_state_shape, 1, output_activ_shape, 1);
  MatchingDim(  // width
      input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
      output_state_shape, 2, output_activ_shape, 2);
  const int input_depth = input_shape.Dims(3);
  const int prev_activ_depth = prev_activ_shape.Dims(3);
  const int total_input_depth = prev_activ_depth + input_depth;
  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
                   total_input_depth);
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
  const int intern_activ_depth =
      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
                   intern_activ_depth * total_input_depth);
  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
  const int output_depth =
      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
                  3, output_activ_shape, 3);
  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);

  // Concatenate prev_activ and input data together.
  std::vector<float const*> concat_input_arrays_data;
  std::vector<RuntimeShape const*> concat_input_arrays_shapes;
  concat_input_arrays_data.push_back(input_data);
  concat_input_arrays_data.push_back(prev_activ_data);
  concat_input_arrays_shapes.push_back(&input_shape);
  concat_input_arrays_shapes.push_back(&prev_activ_shape);
  tflite::ConcatenationParams concat_params;
  concat_params.axis = 3;
  concat_params.inputs_count = concat_input_arrays_data.size();
  Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
                &(concat_input_arrays_data[0]), concat_temp_shape,
                concat_temp_data);

  // Fully connected.
  tflite::FullyConnectedParams fc_params;
  fc_params.float_activation_min = std::numeric_limits<float>::lowest();
  fc_params.float_activation_max = std::numeric_limits<float>::max();
  FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
                 weights_data, bias_shape, bias_data, activ_temp_shape,
                 activ_temp_data);

  // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
  // operations.
  ArrayMap<float> activ_temp_map =
      MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape);
  auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
                                            activ_temp_map.cols());
  auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
                                           activ_temp_map.cols());
  auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth,
                                             activ_temp_map.cols());
  auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
                                             activ_temp_map.cols());
  ArrayMap<const float> prev_state_map =
      MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape);
  ArrayMap<float> output_state_map =
      MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape);
  ArrayMap<float> output_activ_map =
      MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape);

  // Combined memory state and final output calculation.
  ruy::profiler::ScopeLabel label2("MemoryStateAndFinalOutput");
  output_state_map =
      input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
          new_input_sm.tanh() +
      forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
          prev_state_map;
  output_activ_map =
      output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
      output_state_map.tanh();
}
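
// Scalar reference for the Eigen expressions above (illustrative sketch
// only). Per element, with i/g/f/o denoting the four gate pre-activations
// produced by the fully connected node, in that block order:
//
//   float sigmoid(float x) { return 1.f / (1.f + std::exp(-x)); }
//   new_state    = sigmoid(i) * std::tanh(g) + sigmoid(f) * prev_state;
//   output_activ = sigmoid(o) * std::tanh(new_state);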

inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
                     const float* prev_activ_data,
                     const Dims<4>& prev_activ_dims, const float* weights_data,
                     const Dims<4>& weights_dims, const float* bias_data,
                     const Dims<4>& bias_dims, const float* prev_state_data,
                     const Dims<4>& prev_state_dims, float* output_state_data,
                     const Dims<4>& output_state_dims, float* output_activ_data,
                     const Dims<4>& output_activ_dims, float* concat_temp_data,
                     const Dims<4>& concat_temp_dims, float* activ_temp_data,
                     const Dims<4>& activ_temp_dims) {
  tflite::LstmCellParams op_params;
  // Float LSTM cell does not need parameters to be set: leave untouched.

  LstmCell(op_params, DimsToShape(input_dims), input_data,
           DimsToShape(prev_activ_dims), prev_activ_data,
           DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
           bias_data, DimsToShape(prev_state_dims), prev_state_data,
           DimsToShape(output_state_dims), output_state_data,
           DimsToShape(output_activ_dims), output_activ_data,
           DimsToShape(concat_temp_dims), concat_temp_data,
           DimsToShape(activ_temp_dims), activ_temp_data);
}

template <int StateIntegerBits>
inline void LstmCell(
    const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
    const uint8* input_data_uint8,
    const RuntimeShape& unextended_prev_activ_shape,
    const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
    const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
    const int32* bias_data_int32,
    const RuntimeShape& unextended_prev_state_shape,
    const int16* prev_state_data_int16,
    const RuntimeShape& unextended_output_state_shape,
    int16* output_state_data_int16,
    const RuntimeShape& unextended_output_activ_shape,
    uint8* output_activ_data_uint8,
    const RuntimeShape& unextended_concat_temp_shape,
    uint8* concat_temp_data_uint8,
    const RuntimeShape& unextended_activ_temp_shape,
    int16* activ_temp_data_int16, gemmlowp::GemmContext* gemmlowp_context) {
  ruy::profiler::ScopeLabel label(
      "LstmCell/quantized (8bit external, 16bit internal)");
  int32 weights_zero_point = params.weights_zero_point;
  int32 accum_multiplier = params.accum_multiplier;
  int accum_shift = params.accum_shift;
  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
  const RuntimeShape input_shape =
      RuntimeShape::ExtendedShape(4, unextended_input_shape);
  const RuntimeShape prev_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
  const RuntimeShape bias_shape =
      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
  const RuntimeShape prev_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
  const RuntimeShape output_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
  const RuntimeShape output_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
  const RuntimeShape concat_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
  const RuntimeShape activ_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);

  // Gather dimensions information, and perform consistency checks.
  const int weights_dim_count = weights_shape.DimensionsCount();
  const int outer_size = MatchingFlatSizeSkipDim(
      input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
      output_activ_shape);
  const int input_depth = input_shape.Dims(3);
  const int prev_activ_depth = prev_activ_shape.Dims(3);
  const int total_input_depth = prev_activ_depth + input_depth;
  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
                   total_input_depth);
  const int intern_activ_depth =
      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
                   intern_activ_depth * total_input_depth);
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
  const int output_depth =
      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
                  3, output_activ_shape, 3);
  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
  const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
  const int fc_output_depth =
      MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
  const int fc_accum_depth = total_input_depth;
  TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);

  // Depth-concatenate prev_activ and input data together.
  uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
                                              prev_activ_data_uint8};
  const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
                                                       &prev_activ_shape};
  tflite::ConcatenationParams concat_params;
  concat_params.axis = 3;
  concat_params.inputs_count = 2;
  Concatenation(concat_params, concat_input_arrays_shapes,
                concat_input_arrays_data, concat_temp_shape,
                concat_temp_data_uint8);

  // Implementation of the fully connected node inside the LSTM cell.
  // The operands are 8-bit integers, the accumulators are internally 32-bit
  // integers, and the output is 16-bit fixed-point with 3 integer bits so
  // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
  // is explained in the function comment above.
  bool gemm_already_performed = false;
#ifdef GEMMLOWP_NEON
  if (fc_batches == 1 && !(fc_output_depth % 4) && !(fc_accum_depth % 8)) {
    GEMVForLstmCell(concat_temp_shape, concat_temp_data_uint8, weights_shape,
                    weights_data_uint8, weights_zero_point, bias_shape,
                    bias_data_int32, accum_multiplier, accum_shift,
                    activ_temp_shape, activ_temp_data_int16);
    gemm_already_performed = true;
  }
#endif
  if (!gemm_already_performed) {
    gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor>
        weights_matrix(weights_data_uint8, fc_output_depth, fc_accum_depth);
    gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
        concat_temp_data_uint8, fc_accum_depth, fc_batches);
    gemmlowp::MatrixMap<int16, gemmlowp::MapOrder::ColMajor> output_matrix(
        activ_temp_data_int16, fc_output_depth, fc_batches);
    typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
        ColVectorMap;
    ColVectorMap bias_vector(bias_data_int32, fc_output_depth);
    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
    bias_addition_stage.bias_vector = bias_vector;
    gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
    scale_stage.result_offset_after_shift = 0;
    scale_stage.result_fixedpoint_multiplier = accum_multiplier;
    scale_stage.result_exponent = accum_shift;
    gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
    auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
                                           saturating_cast_int16_stage);
    gemmlowp::GemmWithOutputPipeline<
        uint8, int16, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
        gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
        -weights_zero_point, -128, output_pipeline);
  }

  // Rest of the LSTM cell: tanh and logistic math functions, and some adds
  // and muls, all done in 16-bit fixed-point.
  const int16* input_gate_input_ptr = activ_temp_data_int16;
  const int16* input_modulation_gate_input_ptr =
      activ_temp_data_int16 + output_depth;
  const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth;
  const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth;
  const int16* prev_state_ptr = prev_state_data_int16;
  int16* output_state_data_ptr = output_state_data_int16;
  uint8* output_activ_data_ptr = output_activ_data_uint8;

  for (int b = 0; b < outer_size; ++b) {
    int c = 0;
#ifdef GEMMLOWP_NEON
    for (; c <= output_depth - 8; c += 8) {
      // Define the fixed-point data types that we will use here. All use
      // int16 as the underlying integer type, i.e. all are 16-bit fixed-point.
      // They only differ by the number of integral vs. fractional bits,
      // determining the range of values that they can represent.
      //
      // F0 uses 0 integer bits, range [-1, 1].
      // This is the return type of math functions such as tanh, logistic,
      // whose range is in [-1, 1].
      using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
      // F3 uses 3 integer bits, range [-8, 8].
      // This is the range of the previous fully-connected node's output,
      // which is our input here.
      using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
      // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
      // 2^StateIntegerBits]. It's used to represent the internal state, whose
      // number of integer bits is currently dictated by the model. See comment
      // on the StateIntegerBits template parameter above.
      using FS = gemmlowp::FixedPoint<int16x8_t, StateIntegerBits>;
      // Implementation of input gate, using fixed-point logistic function.
      F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr));
      input_gate_input_ptr += 8;
      F0 input_gate_output = gemmlowp::logistic(input_gate_input);
      // Implementation of input modulation gate, using fixed-point tanh
      // function.
      F3 input_modulation_gate_input =
          F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr));
      input_modulation_gate_input_ptr += 8;
      F0 input_modulation_gate_output =
          gemmlowp::tanh(input_modulation_gate_input);
      // Implementation of forget gate, using fixed-point logistic function.
      F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr));
      forget_gate_input_ptr += 8;
      F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
      // Implementation of output gate, using fixed-point logistic function.
      F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr));
      output_gate_input_ptr += 8;
      F0 output_gate_output = gemmlowp::logistic(output_gate_input);
      // Implementation of internal multiplication nodes, still in fixed-point.
      F0 input_times_input_modulation =
          input_gate_output * input_modulation_gate_output;
      FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr));
      prev_state_ptr += 8;
      FS prev_state_times_forget_state = forget_gate_output * prev_state;
      // Implementation of internal addition node, saturating.
      FS new_state = gemmlowp::SaturatingAdd(
          gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
          prev_state_times_forget_state);
      // Implementation of last internal Tanh node, still in fixed-point.
      // Since a Tanh fixed-point implementation is specialized for a given
      // number of integer bits, and each specialization can have a substantial
      // code size, and we already used above a Tanh on an input with 3 integer
      // bits, and per the table in the above function comment there is no
      // significant accuracy to be lost by clamping to [-8, +8] for a
      // 3-integer-bits representation, let us just do that. This helps people
      // porting this to targets where code footprint must be minimized.
      F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
      F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
      // Store the new internal state back to memory, as 16-bit integers.
      // Note: here we store the original value with StateIntegerBits, not
      // the rescaled 3-integer-bits value fed to tanh.
      vst1q_s16(output_state_data_ptr, new_state.raw());
      output_state_data_ptr += 8;
      // Down-scale the output activations to 8-bit integers, saturating,
      // and store back to memory.
      int16x8_t rescaled_output_activ =
          gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
      int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ);
      uint8x8_t uint8_output_activ =
          vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ));
      vst1_u8(output_activ_data_ptr, uint8_output_activ);
      output_activ_data_ptr += 8;
    }
#endif
    for (; c < output_depth; ++c) {
      // Define the fixed-point data types that we will use here. All use
      // int16 as the underlying integer type, i.e. all are 16-bit fixed-point.
      // They only differ by the number of integral vs. fractional bits,
      // determining the range of values that they can represent.
      //
      // F0 uses 0 integer bits, range [-1, 1].
      // This is the return type of math functions such as tanh, logistic,
      // whose range is in [-1, 1].
      using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
      // F3 uses 3 integer bits, range [-8, 8].
      // This is the range of the previous fully-connected node's output,
      // which is our input here.
      using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
      // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
      // 2^StateIntegerBits]. It's used to represent the internal state, whose
      // number of integer bits is currently dictated by the model. See comment
      // on the StateIntegerBits template parameter above.
      using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
      // Implementation of input gate, using fixed-point logistic function.
      F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++);
      F0 input_gate_output = gemmlowp::logistic(input_gate_input);
      // Implementation of input modulation gate, using fixed-point tanh
      // function.
      F3 input_modulation_gate_input =
          F3::FromRaw(*input_modulation_gate_input_ptr++);
      F0 input_modulation_gate_output =
          gemmlowp::tanh(input_modulation_gate_input);
      // Implementation of forget gate, using fixed-point logistic function.
      F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++);
      F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
      // Implementation of output gate, using fixed-point logistic function.
      F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++);
      F0 output_gate_output = gemmlowp::logistic(output_gate_input);
      // Implementation of internal multiplication nodes, still in fixed-point.
      F0 input_times_input_modulation =
          input_gate_output * input_modulation_gate_output;
      FS prev_state = FS::FromRaw(*prev_state_ptr++);
      FS prev_state_times_forget_state = forget_gate_output * prev_state;
      // Implementation of internal addition node, saturating.
      FS new_state = gemmlowp::SaturatingAdd(
          gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
          prev_state_times_forget_state);
      // Implementation of last internal Tanh node, still in fixed-point.
      // Since a Tanh fixed-point implementation is specialized for a given
      // number of integer bits, and each specialization can have a substantial
      // code size, and we already used above a Tanh on an input with 3 integer
      // bits, and per the table in the above function comment there is no
      // significant accuracy to be lost by clamping to [-8, +8] for a
      // 3-integer-bits representation, let us just do that. This helps people
      // porting this to targets where code footprint must be minimized.
      F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
      F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
      // Store the new internal state back to memory, as 16-bit integers.
      // Note: here we store the original value with StateIntegerBits, not
      // the rescaled 3-integer-bits value fed to tanh.
      *output_state_data_ptr++ = new_state.raw();
      // Down-scale the output activations to 8-bit integers, saturating,
      // and store back to memory.
      int16 rescaled_output_activ =
          gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
      int16 clamped_output_activ =
          std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
      *output_activ_data_ptr++ = 128 + clamped_output_activ;
    }
    input_gate_input_ptr += 3 * output_depth;
    input_modulation_gate_input_ptr += 3 * output_depth;
    forget_gate_input_ptr += 3 * output_depth;
    output_gate_input_ptr += 3 * output_depth;
  }
}
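
// Illustrative sketch of the 16-bit fixed-point formats used above. A
// gemmlowp::FixedPoint<int16, kIntegerBits> value has 1 sign bit,
// kIntegerBits integer bits and (15 - kIntegerBits) fractional bits, so its
// real value is raw / 2^(15 - kIntegerBits):
//
//   float F3ToFloat(int16 raw) { return raw / 4096.f; }   // 2^12
//   float F0ToFloat(int16 raw) { return raw / 32768.f; }  // 2^15
//
// The final down-scale step takes the F0 result, divides the raw value by
// 2^8 with rounding, saturates to [-128, 127] and re-centers by +128,
// yielding a uint8 activation with zero point 128.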

template <int StateIntegerBits>
void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
              const uint8* prev_activ_data_uint8,
              const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
              const Dims<4>& weights_dims, const int32* bias_data_int32,
              const Dims<4>& bias_dims, const int16* prev_state_data_int16,
              const Dims<4>& prev_state_dims, int16* output_state_data_int16,
              const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
              const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
              const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
              const Dims<4>& activ_temp_dims, int32 weights_zero_point,
              int32 accum_multiplier, int accum_shift,
              gemmlowp::GemmContext* gemmlowp_context) {
  tflite::LstmCellParams op_params;
  op_params.weights_zero_point = weights_zero_point;
  op_params.accum_multiplier = accum_multiplier;
  op_params.accum_shift = accum_shift;

  LstmCell<StateIntegerBits>(
      op_params, DimsToShape(input_dims), input_data_uint8,
      DimsToShape(prev_activ_dims), prev_activ_data_uint8,
      DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
      bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
      DimsToShape(output_state_dims), output_state_data_int16,
      DimsToShape(output_activ_dims), output_activ_data_uint8,
      DimsToShape(concat_temp_dims), concat_temp_data_uint8,
      DimsToShape(activ_temp_dims), activ_temp_data_int16, gemmlowp_context);
}

template <typename T>
void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
                  const T* input2_data, const Dims<4>& input2_dims,
                  T output_activation_min, T output_activation_max,
                  T* output_data, const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  SetActivationParams(output_activation_min, output_activation_max, &op_params);

  BroadcastDivSlow(op_params, DimsToShape(input1_dims), input1_data,
                   DimsToShape(input2_dims), input2_data,
                   DimsToShape(output_dims), output_data);
}

template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
                     float* output_data, const RuntimeShape& output_shape) {
  static_assert(Ac == FusedActivationFunctionType::kNone, "");
  tflite::L2NormalizationParams op_params;
  // No params need to be set for float, but reserved in signature for future
  // activations.

  L2Normalization(op_params, input_shape, input_data, output_shape,
                  output_data);
}
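
// Illustrative sketch: along the innermost dimension, L2Normalization
// rescales each slice to unit Euclidean norm:
//
//   output[i] = input[i] / std::sqrt(sum_j input[j] * input[j]);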

inline void L2Normalization(const uint8* input_data,
                            const RuntimeShape& input_shape,
                            int32 input_zero_point, uint8* output_data,
                            const RuntimeShape& output_shape) {
  tflite::L2NormalizationParams op_params;
  op_params.input_zero_point = input_zero_point;

  L2Normalization(op_params, input_shape, input_data, output_shape,
                  output_data);
}

template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const Dims<4>& input_dims,
                     float* output_data, const Dims<4>& output_dims) {
  L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
                      DimsToShape(output_dims));
}

inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
                            int32 input_zero_point, uint8* output_data,
                            const Dims<4>& output_dims) {
  L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
                  output_data, DimsToShape(output_dims));
}

inline void Relu(const float* input_data, const Dims<4>& input_dims,
                 float* output_data, const Dims<4>& output_dims) {
  Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
       output_data);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void Add(const float* input1_data, const Dims<4>& input1_dims,
         const float* input2_data, const Dims<4>& input2_dims,
         float* output_data, const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);

  tflite::ArithmeticParams op_params;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;
  Add(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}

template <FusedActivationFunctionType Ac>
inline void Add(int left_shift, const uint8* input1_data,
                const Dims<4>& input1_dims, int32 input1_offset,
                int32 input1_multiplier, int input1_shift,
                const uint8* input2_data, const Dims<4>& input2_dims,
                int32 input2_offset, int32 input2_multiplier, int input2_shift,
                int32 output_offset, int32 output_multiplier, int output_shift,
                int32 output_activation_min, int32 output_activation_max,
                uint8* output_data, const Dims<4>& output_dims) {
  constexpr int kReverseShift = -1;
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }

  tflite::ArithmeticParams op_params;
  op_params.left_shift = left_shift;
  op_params.input1_offset = input1_offset;
  op_params.input1_multiplier = input1_multiplier;
  op_params.input1_shift = kReverseShift * input1_shift;
  op_params.input2_offset = input2_offset;
  op_params.input2_multiplier = input2_multiplier;
  op_params.input2_shift = kReverseShift * input2_shift;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  op_params.output_shift = kReverseShift * output_shift;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  Add(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}
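
// Note on kReverseShift, used throughout these legacy adapters: the old API
// passed shifts as positive right-shift amounts, while the ArithmeticParams
// API expects a signed exponent (negative meaning right shift). Multiplying
// by kReverseShift == -1 converts between the two conventions, e.g.:
//
//   int legacy_output_shift = 3;  // right-shift by 3 in the old convention
//   op_params.output_shift = kReverseShift * legacy_output_shift;  // -3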

template <FusedActivationFunctionType Ac>
void Add(const int32* input1_data, const Dims<4>& input1_dims,
         const int32* input2_data, const Dims<4>& input2_dims,
         int32* output_data, const Dims<4>& output_dims) {
  ruy::profiler::ScopeLabel label("Add/int32");
  TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);

  tflite::ArithmeticParams op_params;
  op_params.quantized_activation_min = std::numeric_limits<int32>::min();
  op_params.quantized_activation_max = std::numeric_limits<int32>::max();
  Add(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}

template <typename T>
void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
                  const T* input2_data, const Dims<4>& input2_dims,
                  T output_activation_min, T output_activation_max,
                  T* output_data, const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;
  BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
                     DimsToShape(input2_dims), input2_data,
                     DimsToShape(output_dims), output_data);
}

template <FusedActivationFunctionType Ac>
inline void BroadcastAdd(int left_shift, const uint8* input1_data,
                         const Dims<4>& input1_dims, int32 input1_offset,
                         int32 input1_multiplier, int input1_shift,
                         const uint8* input2_data, const Dims<4>& input2_dims,
                         int32 input2_offset, int32 input2_multiplier,
                         int input2_shift, int32 output_offset,
                         int32 output_multiplier, int output_shift,
                         int32 output_activation_min,
                         int32 output_activation_max, uint8* output_data,
                         const Dims<4>& output_dims) {
  constexpr int kReverseShift = -1;
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }

  tflite::ArithmeticParams op_params;
  op_params.left_shift = left_shift;
  op_params.input1_offset = input1_offset;
  op_params.input1_multiplier = input1_multiplier;
  op_params.input1_shift = kReverseShift * input1_shift;
  op_params.input2_offset = input2_offset;
  op_params.input2_multiplier = input2_multiplier;
  op_params.input2_shift = kReverseShift * input2_shift;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  op_params.output_shift = kReverseShift * output_shift;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
                     DimsToShape(input2_dims), input2_data,
                     DimsToShape(output_dims), output_data);
}

template <FusedActivationFunctionType Ac>
inline void BroadcastAddFivefold(
    int y0, int y1, int y2, int y3, int y4, int left_shift,
    const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
    int32 input1_multiplier, int input1_shift, const uint8* input2_data,
    const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
    int input2_shift, int32 output_offset, int32 output_multiplier,
    int output_shift, int32 output_activation_min, int32 output_activation_max,
    uint8* output_data, const Dims<4>& output_dims) {
  constexpr int kReverseShift = -1;
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  tflite::ArithmeticParams op_params;
  op_params.broadcast_category =
      tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
  op_params.left_shift = left_shift;
  op_params.input1_offset = input1_offset;
  op_params.input1_multiplier = input1_multiplier;
  op_params.input1_shift = kReverseShift * input1_shift;
  op_params.input2_offset = input2_offset;
  op_params.input2_multiplier = input2_multiplier;
  op_params.input2_shift = kReverseShift * input2_shift;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  op_params.output_shift = kReverseShift * output_shift;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  op_params.broadcast_shape[4] = y0;
  op_params.broadcast_shape[3] = y1;
  op_params.broadcast_shape[2] = y2;
  op_params.broadcast_shape[1] = y3;
  op_params.broadcast_shape[0] = y4;
  BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
                       DimsToShape(input2_dims), input2_data,
                       DimsToShape(output_dims), output_data);
}
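
// Note: the fivefold variant covers broadcasts whose shapes collapse into
// five nested loop extents y0..y4. The legacy API passed those extents in
// the opposite order from ArithmeticParams::broadcast_shape, hence the
// reversed assignment above (y0 lands in broadcast_shape[4]).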
3659 
3660 // legacy, for compatibility with old checked-in code
3661 template <FusedActivationFunctionType Ac, typename T>
BroadcastAdd(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)3662 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
3663                   const T* input2_data, const Dims<4>& input2_dims,
3664                   T* output_data, const Dims<4>& output_dims) {
3665   T output_activation_min, output_activation_max;
3666   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3667 
3668   BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
3669                output_activation_min, output_activation_max, output_data,
3670                output_dims);
3671 }
3672 
3673 template <FusedActivationFunctionType Ac>
Add(const int16 * input1_data,const Dims<4> & input1_dims,int input1_shift,const int16 * input2_data,const Dims<4> & input2_dims,int input2_shift,int16 output_activation_min,int16 output_activation_max,int16 * output_data,const Dims<4> & output_dims)3674 inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
3675                 int input1_shift, const int16* input2_data,
3676                 const Dims<4>& input2_dims, int input2_shift,
3677                 int16 output_activation_min, int16 output_activation_max,
3678                 int16* output_data, const Dims<4>& output_dims) {
3679   constexpr int kReverseShift = -1;
3680   static_assert(Ac == FusedActivationFunctionType::kNone ||
3681                     Ac == FusedActivationFunctionType::kRelu ||
3682                     Ac == FusedActivationFunctionType::kRelu6 ||
3683                     Ac == FusedActivationFunctionType::kRelu1,
3684                 "");
3685   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
3686   if (Ac == FusedActivationFunctionType::kNone) {
3687     TFLITE_DCHECK_EQ(output_activation_min, -32768);
3688     TFLITE_DCHECK_EQ(output_activation_max, 32767);
3689   }
3690 
3691   tflite::ArithmeticParams op_params;
3692   op_params.input1_shift = kReverseShift * input1_shift;
3693   op_params.input2_shift = kReverseShift * input2_shift;
3694   op_params.quantized_activation_min = output_activation_min;
3695   op_params.quantized_activation_max = output_activation_max;
3696   Add(op_params, DimsToShape(input1_dims), input1_data,
3697       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3698       output_data);
3699 }

inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
                const float* input2_data, const Dims<4>& input2_dims,
                float* output_data, const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(FusedActivationFunctionType::kNone,
                      &output_activation_min, &output_activation_max);
  tflite::ArithmeticParams op_params;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;
  Sub(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}

template <typename T>
void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
         const Dims<4>& input2_dims, T* output_data,
         const Dims<4>& output_dims) {
  T output_activation_min, output_activation_max;
  GetActivationMinMax(FusedActivationFunctionType::kNone,
                      &output_activation_min, &output_activation_max);
  tflite::ArithmeticParams op_params;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  Sub(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}
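
// Caveat (my reading of the template above, not an upstream note): the
// generic Sub fills in the quantized activation fields, so it only makes
// sense for integer element types; float tensors resolve to the dedicated
// float overload before it. A sketch of the intended dispatch:
//
//   Sub(f32_in1, dims, f32_in2, dims, f32_out, dims);        // float overload
//   Sub<int32>(i32_in1, dims, i32_in2, dims, i32_out, dims); // template
//
// (buffer and dims names are hypothetical).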

inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
                         int32 input1_offset, const uint8* input2_data,
                         const Dims<4>& input2_dims, int32 input2_offset,
                         int32 output_offset, int32 output_multiplier,
                         int output_shift, int32 output_activation_min,
                         int32 output_activation_max, uint8* output_data,
                         const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  SetActivationParams(output_activation_min, output_activation_max, &op_params);
  op_params.input1_offset = input1_offset;
  op_params.input2_offset = input2_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  op_params.output_shift = kReverseShift * output_shift;

  BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
                     DimsToShape(input2_dims), input2_data,
                     DimsToShape(output_dims), output_data);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
                         int32 input1_offset, const uint8* input2_data,
                         const Dims<4>& input2_dims, int32 input2_offset,
                         int32 output_offset, int32 output_multiplier,
                         int output_shift, int32 output_activation_min,
                         int32 output_activation_max, uint8* output_data,
                         const Dims<4>& output_dims) {
  BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
               input2_dims, input2_offset, output_offset, output_multiplier,
               output_shift, output_activation_min, output_activation_max,
               output_data, output_dims);
}
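
// The pair above illustrates the adapter pattern used throughout this
// header: legacy call sites hold Dims<4> plus loose quantization scalars,
// which get packed into an ArithmeticParams and RuntimeShapes for the
// modern kernel. A sketch with hypothetical values (offsets are the
// negated zero points, here for zero points of 128):
//
//   BroadcastMul(in1, dims1, /*input1_offset=*/-128,
//                in2, dims2, /*input2_offset=*/-128,
//                /*output_offset=*/128, output_multiplier, output_shift,
//                /*output_activation_min=*/0, /*output_activation_max=*/255,
//                out, out_dims);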

inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
                        int stride_width, int stride_height, int pad_width,
                        int pad_height, int kwidth, int kheight,
                        float output_activation_min,
                        float output_activation_max, float* output_data,
                        const Dims<4>& output_dims) {
  tflite::PoolParams params;
  params.stride_height = stride_height;
  params.stride_width = stride_width;
  params.filter_height = kheight;
  params.filter_width = kwidth;
  params.padding_values.height = pad_height;
  params.padding_values.width = pad_width;
  params.float_activation_min = output_activation_min;
  params.float_activation_max = output_activation_max;
  AveragePool(params, DimsToShape(input_dims), input_data,
              DimsToShape(output_dims), output_data);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void AveragePool(const float* input_data, const Dims<4>& input_dims,
                 int stride_width, int stride_height, int pad_width,
                 int pad_height, int kwidth, int kheight, float* output_data,
                 const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);

  AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
              pad_height, kwidth, kheight, output_activation_min,
              output_activation_max, output_data, output_dims);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
                 int pad_width, int pad_height, int filter_width,
                 int filter_height, float* output_data,
                 const Dims<4>& output_dims) {
  AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
                  filter_width, filter_height, output_data, output_dims);
}

inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
                        int stride_width, int stride_height, int pad_width,
                        int pad_height, int filter_width, int filter_height,
                        int32 output_activation_min,
                        int32 output_activation_max, uint8* output_data,
                        const Dims<4>& output_dims) {
  tflite::PoolParams params;
  params.stride_height = stride_height;
  params.stride_width = stride_width;
  params.filter_height = filter_height;
  params.filter_width = filter_width;
  params.padding_values.height = pad_height;
  params.padding_values.width = pad_width;
  params.quantized_activation_min = output_activation_min;
  params.quantized_activation_max = output_activation_max;
  AveragePool(params, DimsToShape(input_dims), input_data,
              DimsToShape(output_dims), output_data);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
                 int stride_width, int stride_height, int pad_width,
                 int pad_height, int filter_width, int filter_height,
                 int32 output_activation_min, int32 output_activation_max,
                 uint8* output_data, const Dims<4>& output_dims) {
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
              pad_height, filter_width, filter_height, output_activation_min,
              output_activation_max, output_data, output_dims);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
                 int pad_width, int pad_height, int filter_width,
                 int filter_height, int32 output_activation_min,
                 int32 output_activation_max, uint8* output_data,
                 const Dims<4>& output_dims) {
  AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
                  filter_width, filter_height, output_activation_min,
                  output_activation_max, output_data, output_dims);
}
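
// How these wrappers assemble PoolParams, condensed (a minimal sketch with
// hypothetical numbers; the field names are the real ones used above):
//
//   tflite::PoolParams params;
//   params.stride_width = params.stride_height = 2;
//   params.filter_width = params.filter_height = 2;
//   params.padding_values.width = params.padding_values.height = 0;
//   params.quantized_activation_min = 0;
//   params.quantized_activation_max = 255;
//   AveragePool(params, DimsToShape(input_dims), input_data,
//               DimsToShape(output_dims), output_data);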

inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
                    int stride_width, int stride_height, int pad_width,
                    int pad_height, int kwidth, int kheight,
                    float output_activation_min, float output_activation_max,
                    float* output_data, const Dims<4>& output_dims) {
  tflite::PoolParams params;
  params.stride_height = stride_height;
  params.stride_width = stride_width;
  params.filter_height = kheight;
  params.filter_width = kwidth;
  params.padding_values.height = pad_height;
  params.padding_values.width = pad_width;
  params.float_activation_min = output_activation_min;
  params.float_activation_max = output_activation_max;
  MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
          output_data);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void MaxPool(const float* input_data, const Dims<4>& input_dims,
             int stride_width, int stride_height, int pad_width, int pad_height,
             int kwidth, int kheight, float* output_data,
             const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
          pad_height, kwidth, kheight, output_activation_min,
          output_activation_max, output_data, output_dims);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
             int pad_width, int pad_height, int filter_width, int filter_height,
             float* output_data, const Dims<4>& output_dims) {
  MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
              filter_width, filter_height, output_data, output_dims);
}

inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
                    int stride_width, int stride_height, int pad_width,
                    int pad_height, int filter_width, int filter_height,
                    int32 output_activation_min, int32 output_activation_max,
                    uint8* output_data, const Dims<4>& output_dims) {
  PoolParams params;
  params.stride_height = stride_height;
  params.stride_width = stride_width;
  params.filter_height = filter_height;
  params.filter_width = filter_width;
  params.padding_values.height = pad_height;
  params.padding_values.width = pad_width;
  params.quantized_activation_min = output_activation_min;
  params.quantized_activation_max = output_activation_max;
  MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
          output_data);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
             int stride_width, int stride_height, int pad_width, int pad_height,
             int filter_width, int filter_height, int32 output_activation_min,
             int32 output_activation_max, uint8* output_data,
             const Dims<4>& output_dims) {
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
          pad_height, filter_width, filter_height, output_activation_min,
          output_activation_max, output_data, output_dims);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
             int pad_width, int pad_height, int filter_width, int filter_height,
             int32 output_activation_min, int32 output_activation_max,
             uint8* output_data, const Dims<4>& output_dims) {
  MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
              filter_width, filter_height, output_activation_min,
              output_activation_max, output_data, output_dims);
}
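
// Observation (mine, not an upstream note): for kNone the quantized pooling
// wrappers DCHECK the full uint8 range, since "no activation" must not clip.
// A hypothetical call that satisfies the checks:
//
//   MaxPool<FusedActivationFunctionType::kNone>(
//       input, dims, /*stride=*/2, /*pad_width=*/0, /*pad_height=*/0,
//       /*filter_width=*/2, /*filter_height=*/2,
//       /*output_activation_min=*/0, /*output_activation_max=*/255,
//       output, out_dims);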

inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
                   int stride_width, int stride_height, int pad_width,
                   int pad_height, int filter_width, int filter_height,
                   float output_activation_min, float output_activation_max,
                   float* output_data, const Dims<4>& output_dims) {
  PoolParams params;
  params.stride_height = stride_height;
  params.stride_width = stride_width;
  params.filter_height = filter_height;
  params.filter_width = filter_width;
  params.padding_values.height = pad_height;
  params.padding_values.width = pad_width;
  params.float_activation_min = output_activation_min;
  params.float_activation_max = output_activation_max;
  L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
         output_data);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void L2Pool(const float* input_data, const Dims<4>& input_dims,
            int stride_width, int stride_height, int pad_width, int pad_height,
            int filter_width, int filter_height, float* output_data,
            const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
         pad_height, filter_width, filter_height, output_activation_min,
         output_activation_max, output_data, output_dims);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
            int pad_width, int pad_height, int filter_width, int filter_height,
            float* output_data, const Dims<4>& output_dims) {
  L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
             filter_width, filter_height, output_data, output_dims);
}
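
// For reference (the standard definition, not spelled out here): L2 pooling
// emits the root-mean-square of each window W,
//   output = sqrt(sum_{i in W} x_i^2 / |W|),
// which is presumably why only float variants appear in this legacy surface:
// the square root does not map cleanly onto the uint8 fixed-point scheme
// used by the other pooling ops.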

inline void Softmax(const SoftmaxParams& params,
                    const RuntimeShape& input_shape, const uint8* input_data,
                    const RuntimeShape& output_shape, uint8* output_data) {
  const int32 input_beta_multiplier = params.input_multiplier;
  const int32 input_beta_left_shift = params.input_left_shift;
  const int diff_min = params.diff_min;
  // The representation chosen for the input to the exp() function is Q5.26.
  // We need to leave extra space since values that we skip might be as large as
  // -32 before multiplying by input_beta_multiplier, and therefore as large as
  // -16 afterwards.  Note that exp(-8) is definitely not insignificant to
  // accumulation, but exp(-16) definitely is.
  static const int kScaledDiffIntegerBits = 5;
  static const int kAccumulationIntegerBits = 12;
  using FixedPointScaledDiff =
      gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
  using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;

  ruy::profiler::ScopeLabel label("Softmax/8bit");
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int b = 0; b < outer_size; ++b) {
    const uint8* input_data_ptr = input_data + b * depth;
    uint8* output_data_ptr = output_data + b * depth;

    // Determine the largest entry in the current row
    uint8 max_in_row = 0;
    {
      int c = 0;
#ifdef USE_NEON
      uint8x16_t max16_0 = vdupq_n_u8(0);
      uint8x16_t max16_1 = vdupq_n_u8(0);
      for (; c <= depth - 32; c += 32) {
        max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0));
        max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16));
      }
      uint8x16_t max16 = vmaxq_u8(max16_0, max16_1);
      if (c <= depth - 16) {
        max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c));
        c += 16;
      }
      uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16));
      if (c <= depth - 8) {
        max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c));
        c += 8;
      }
      uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4));
      uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2));
      uint8x8_t max1 = vpmax_u8(max2, max2);
      max_in_row = vget_lane_u8(max1, 0);
#endif
      for (; c < depth; ++c) {
        max_in_row = std::max(max_in_row, input_data_ptr[c]);
      }
    }

#ifdef USE_NEON
    using FixedPointAccumInt32x4 =
        gemmlowp::FixedPoint<int32x4_t, kAccumulationIntegerBits>;
    using FixedPointScaledDiffInt32x4 =
        gemmlowp::FixedPoint<int32x4_t, kScaledDiffIntegerBits>;
    using FixedPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
    FixedPoint0Int32x4 input_beta_multiplier_f0 =
        FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier);
    int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row);
#endif

    // Compute the sum of exponentials of the differences of entries in the
    // current row from the largest entry in the current row.
    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
    {
      int c = 0;
#ifdef USE_NEON
      int32x4_t diff_min_s32 = vdupq_n_s32(diff_min);
      FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero();
      FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero();
      FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero();
      for (; c <= depth - 8; c += 8) {
        uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
        int16x8_t input_diff_s16 =
            vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
        int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
        int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
        int32x4_t mask_0 =
            gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32);
        int32x4_t mask_1 =
            gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32);
        FixedPointScaledDiffInt32x4 scaled_diff_0 =
            input_beta_multiplier_f0 *
            FixedPointScaledDiffInt32x4::FromRaw(
                gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
        FixedPointScaledDiffInt32x4 scaled_diff_1 =
            input_beta_multiplier_f0 *
            FixedPointScaledDiffInt32x4::FromRaw(
                gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
        FixedPointAccumInt32x4 exps_0 =
            gemmlowp::Rescale<kAccumulationIntegerBits>(
                exp_on_negative_values(scaled_diff_0));
        FixedPointAccumInt32x4 exps_1 =
            gemmlowp::Rescale<kAccumulationIntegerBits>(
                exp_on_negative_values(scaled_diff_1));
        FixedPointAccumInt32x4 masked_exps_0 =
            SelectUsingMask(mask_0, exps_0, zeros);
        FixedPointAccumInt32x4 masked_exps_1 =
            SelectUsingMask(mask_1, exps_1, zeros);
        sum_of_exps_0 = sum_of_exps_0 + masked_exps_0;
        sum_of_exps_1 = sum_of_exps_1 + masked_exps_1;
      }
      int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw();
      int32x2_t sum_of_exps_reduced_2 =
          vadd_s32(vget_low_s32(sum_of_exps_reduced_4),
                   vget_high_s32(sum_of_exps_reduced_4));
      int32x2_t sum_of_exps_reduced_1 =
          vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2);
      sum_of_exps =
          FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0));
#endif
      for (; c < depth; ++c) {
        int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
        if (input_diff >= diff_min) {
          const int32 input_diff_rescaled =
              MultiplyByQuantizedMultiplierGreaterThanOne(
                  input_diff, input_beta_multiplier, input_beta_left_shift);
          const FixedPointScaledDiff scaled_diff_f8 =
              FixedPointScaledDiff::FromRaw(input_diff_rescaled);
          sum_of_exps =
              sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
                                exp_on_negative_values(scaled_diff_f8));
        }
      }
    }

    // Compute the fixed-point multiplier and shift that we need to apply to
    // perform a division by the above-computed sum-of-exponentials.
    int num_bits_over_unit = 0;
    FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
        sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));

    // Compute the quotients of exponentials of differences of entries in the
    // current row from the largest entry, over the previously-computed sum of
    // exponentials.
    {
      int c = 0;
#ifdef USE_NEON
      int16x8_t diff_min_s16 = vdupq_n_s16(diff_min);
      for (; c <= depth - 8; c += 8) {
        uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
        int16x8_t input_diff_s16 =
            vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
        int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
        int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
        uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16));
        FixedPointScaledDiffInt32x4 scaled_diff_0 =
            input_beta_multiplier_f0 *
            FixedPointScaledDiffInt32x4::FromRaw(
                gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
        FixedPointScaledDiffInt32x4 scaled_diff_1 =
            input_beta_multiplier_f0 *
            FixedPointScaledDiffInt32x4::FromRaw(
                gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
        FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0);
        FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1);
        int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT(
            vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()),
            num_bits_over_unit + 31 - 8);
        int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT(
            vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()),
            num_bits_over_unit + 31 - 8);
        int16x8_t output_s16 =
            vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1));
        uint8x8_t output_u8 = vqmovun_s16(output_s16);
        uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0));
        vst1_u8(output_data_ptr + c, masked_output);
      }
#endif
      for (; c < depth; ++c) {
        int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
        if (input_diff >= diff_min) {
          const int32 input_diff_rescaled =
              MultiplyByQuantizedMultiplierGreaterThanOne(
                  input_diff, input_beta_multiplier, input_beta_left_shift);
          const FixedPointScaledDiff scaled_diff_f8 =
              FixedPointScaledDiff::FromRaw(input_diff_rescaled);

          FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
          int32 unsat_output = gemmlowp::RoundingDivideByPOT(
              (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);

          output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0);

        } else {
          output_data_ptr[c] = 0;
        }
      }
    }
  }
}
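
// Scalar recap of the quantized softmax above (my summary, reusing the
// names from the code): for each entry x in a row,
//   diff        = x - max_in_row                        // always <= 0
//   scaled_diff = beta * diff, in Q5.26 via the quantized multiplier
//   exp_val     = exp_on_negative_values(scaled_diff)   // Q0.31
//   sum_of_exps accumulates the exp_vals rescaled to Q12.19
//   output      = exp_val * (1 / sum_of_exps), reduced to 8 bits by
//                 RoundingDivideByPOT(..., num_bits_over_unit + 31 - 8)
//                 and clamped to [0, 255].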

inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
                    float beta, float* output_data,
                    const RuntimeShape& output_shape) {
  SoftmaxParams params;
  params.beta = beta;
  Softmax(params, input_shape, input_data, output_shape, output_data);
}

inline void Softmax(const float* input_data, const Dims<4>& input_dims,
                    float beta, float* output_data,
                    const Dims<4>& output_dims) {
  Softmax(input_data, DimsToShape(input_dims), beta, output_data,
          DimsToShape(output_dims));
}

inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
                    int32 input_beta_multiplier, int32 input_beta_left_shift,
                    int diff_min, uint8* output_data,
                    const RuntimeShape& output_shape) {
  SoftmaxParams params;
  params.input_multiplier = input_beta_multiplier;
  params.input_left_shift = input_beta_left_shift;
  params.diff_min = diff_min;
  Softmax(params, input_shape, input_data, output_shape, output_data);
}

inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
                    int32 input_beta_multiplier, int32 input_beta_left_shift,
                    int diff_min, uint8* output_data,
                    const Dims<4>& output_dims) {
  Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
          input_beta_left_shift, diff_min, output_data,
          DimsToShape(output_dims));
}

inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
                       float* output_data, const RuntimeShape& output_shape) {
  SoftmaxParams params;
  // No params currently used for float LogSoftmax.
  LogSoftmax(params, input_shape, input_data, output_shape, output_data);
}

inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
                       float* output_data, const Dims<4>& output_dims) {
  LogSoftmax(input_data, DimsToShape(input_dims), output_data,
             DimsToShape(output_dims));
}
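
// Note (added observation): unlike Softmax above, the quantized LogSoftmax
// entry points below carry no optimized implementation in this file; both
// simply pack the parameters and forward to reference_ops::LogSoftmax.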
inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
                       int32 input_multiplier, int32 input_left_shift,
                       int32 reverse_scaling_divisor,
                       int32 reverse_scaling_right_shift, int diff_min,
                       uint8* output_data, const RuntimeShape& output_shape) {
  SoftmaxParams params;
  params.input_multiplier = input_multiplier;
  params.input_left_shift = input_left_shift;
  params.reverse_scaling_divisor = reverse_scaling_divisor;
  params.reverse_scaling_right_shift = reverse_scaling_right_shift;
  params.diff_min = diff_min;
  reference_ops::LogSoftmax(params, input_shape, input_data, output_shape,
                            output_data);
}

inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
                       int32 input_multiplier, int32 input_left_shift,
                       int32 reverse_scaling_divisor,
                       int32 reverse_scaling_right_shift, int diff_min,
                       uint8* output_data, const Dims<4>& output_dims) {
  reference_ops::LogSoftmax(
      input_data, DimsToShape(input_dims), input_multiplier, input_left_shift,
      reverse_scaling_divisor, reverse_scaling_right_shift, diff_min,
      output_data, DimsToShape(output_dims));
}

inline void Logistic(const LogisticParams& params,
                     const RuntimeShape& input_shape, const uint8* input_data,
                     const RuntimeShape& output_shape, uint8* output_data) {
  ruy::profiler::ScopeLabel label("Logistic/Uint8");
  const int32 input_zero_point = params.input_zero_point;
  const int32 input_range_radius = params.input_range_radius;
  const int32 input_multiplier = params.input_multiplier;
  const int input_left_shift = params.input_left_shift;
  const int size = MatchingFlatSize(input_shape, output_shape);

  int c = 0;
#ifdef USE_NEON
  // Handle 16 values at a time
  for (; c <= size - 16; c += 16) {
    // Read input uint8 values, cast to int16 and subtract input_zero_point
    uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
    int16x8_t input_val_centered_0 =
        vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
                  vdupq_n_s16(input_zero_point));
    int16x8_t input_val_centered_1 =
        vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
                  vdupq_n_s16(input_zero_point));

    // Prepare the bit masks that we will use at the end to implement the logic
    // that was expressed in the scalar code with branching:
    //   if (input_val_centered < -input_range_radius) {
    //     output_val = 0;
    //   } else if (input_val_centered > input_range_radius) {
    //     output_val = 255;
    //   } else {
    //     ...
    uint16x8_t mask_rightclamp_0 =
        vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
    uint16x8_t mask_rightclamp_1 =
        vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
    uint16x8_t mask_leftclamp_0 =
        vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
    uint16x8_t mask_leftclamp_1 =
        vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
    uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
                                             vshrn_n_u16(mask_rightclamp_1, 8));
    uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
                                            vshrn_n_u16(mask_leftclamp_1, 8));

    // This performs what is expressed in the scalar code as
    // const int32 input_val_rescaled =
    //     MultiplyByQuantizedMultiplierGreaterThanOne(
    //         input_val_centered, input_multiplier, input_left_shift);
    int32x4_t input_val_rescaled_0 =
        vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
                  vdupq_n_s32(input_left_shift));
    int32x4_t input_val_rescaled_1 =
        vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
                  vdupq_n_s32(input_left_shift));
    int32x4_t input_val_rescaled_2 =
        vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
                  vdupq_n_s32(input_left_shift));
    int32x4_t input_val_rescaled_3 =
        vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
                  vdupq_n_s32(input_left_shift));
    input_val_rescaled_0 =
        vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
    input_val_rescaled_1 =
        vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
    input_val_rescaled_2 =
        vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
    input_val_rescaled_3 =
        vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);

    // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t
    using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
    using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
    const FixedPoint4 input_val_f4_0 =
        FixedPoint4::FromRaw(input_val_rescaled_0);
    const FixedPoint4 input_val_f4_1 =
        FixedPoint4::FromRaw(input_val_rescaled_1);
    const FixedPoint4 input_val_f4_2 =
        FixedPoint4::FromRaw(input_val_rescaled_2);
    const FixedPoint4 input_val_f4_3 =
        FixedPoint4::FromRaw(input_val_rescaled_3);
    const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0);
    const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1);
    const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2);
    const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3);

    // Divide by 2^23 as in the scalar code
    using gemmlowp::RoundingDivideByPOT;
    int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23);
    int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23);
    int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23);
    int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23);

    // Cast output values to uint8, saturating
    int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
                                              vqmovn_s32(output_val_s32_1));
    int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
                                              vqmovn_s32(output_val_s32_3));
    uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
                                           vqmovun_s16(output_val_s16_1));

    // Perform the bit-masking with the bit masks computed at the beginning,
    // see the comment there.
    output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
    output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);

    // Store back to memory
    vst1q_u8(output_data + c, output_val_u8);
  }
#endif
  // Leftover loop: handle one value at a time with scalar code.
  for (; c < size; ++c) {
    const uint8 input_val_u8 = input_data[c];
    const int32 input_val_centered =
        static_cast<int32>(input_val_u8) - input_zero_point;
    uint8 output_val;
    if (input_val_centered < -input_range_radius) {
      output_val = 0;
    } else if (input_val_centered > input_range_radius) {
      output_val = 255;
    } else {
      const int32 input_val_rescaled =
          MultiplyByQuantizedMultiplierGreaterThanOne(
              input_val_centered, input_multiplier, input_left_shift);
      using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
      using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
      const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
      const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
      using gemmlowp::RoundingDivideByPOT;
      int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
      if (output_val_s32 == 256) {
        output_val_s32 = 255;
      }
      TFLITE_DCHECK_GE(output_val_s32, 0);
      TFLITE_DCHECK_LE(output_val_s32, 255);
      output_val = static_cast<uint8>(output_val_s32);
    }
    output_data[c] = output_val;
  }
}
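
// Scalar recap of the quantized logistic above (my summary, same names as
// the code): inputs within input_range_radius of the zero point go through
// gemmlowp::logistic in Q4.27 fixed point; the Q0.31 result in [0, 1) is
// reduced to 8 bits by the divide-by-2^23 (with the 256 -> 255 saturation);
// anything outside that radius clamps straight to 0 or 255.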

inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
                     int32 input_zero_point, int32 input_range_radius,
                     int32 input_multiplier, int input_left_shift,
                     uint8* output_data, const RuntimeShape& output_shape) {
  LogisticParams params;
  params.input_zero_point = input_zero_point;
  params.input_range_radius = input_range_radius;
  params.input_multiplier = input_multiplier;
  params.input_left_shift = input_left_shift;
  Logistic(params, input_shape, input_data, output_shape, output_data);
}

inline void Logistic(const float* input_data, const Dims<4>& input_dims,
                     float* output_data, const Dims<4>& output_dims) {
  Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
           output_data);
}

inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
                     int32 input_zero_point, int32 input_range_radius,
                     int32 input_multiplier, int input_left_shift,
                     uint8* output_data, const Dims<4>& output_dims) {
  Logistic(input_data, DimsToShape(input_dims), input_zero_point,
           input_range_radius, input_multiplier, input_left_shift, output_data,
           DimsToShape(output_dims));
}

inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
                     const RuntimeShape& output_shape, int16* output_data) {
  LogisticParams params;
  // No params currently needed by int16 Logistic.
  Logistic(params, input_shape, input_data, output_shape, output_data);
}

inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
                     int16* output_data, const RuntimeShape& output_shape) {
  LogisticParams params;
  // No params currently needed by int16 Logistic.
  Logistic(params, input_shape, input_data, output_shape, output_data);
}

inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
                     int16* output_data, const Dims<4>& output_dims) {
  Logistic(input_data, DimsToShape(input_dims), output_data,
           DimsToShape(output_dims));
}

inline void Tanh(const float* input_data, const Dims<4>& input_dims,
                 float* output_data, const Dims<4>& output_dims) {
  Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
       output_data);
}

inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
                 const uint8* input_data, const RuntimeShape& output_shape,
                 uint8* output_data) {
  // Note that this is almost the exact same code as in Logistic().
  ruy::profiler::ScopeLabel label("Tanh");
  const int32 input_zero_point = params.input_zero_point;
  const int32 input_range_radius = params.input_range_radius;
  const int32 input_multiplier = params.input_multiplier;
  const int input_left_shift = params.input_left_shift;
  const int size = MatchingFlatSize(input_shape, output_shape);

  int c = 0;
  int32_t output_zero_point = 128;
#ifdef USE_NEON
  // Handle 16 values at a time
  for (; c <= size - 16; c += 16) {
    // Read input uint8 values, cast to int16 and subtract input_zero_point
    uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
    int16x8_t input_val_centered_0 =
        vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
                  vdupq_n_s16(input_zero_point));
    int16x8_t input_val_centered_1 =
        vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
                  vdupq_n_s16(input_zero_point));

    // Prepare the bit masks that we will use at the end to implement the logic
    // that was expressed in the scalar code with branching:
    //   if (input_val_centered < -input_range_radius) {
    //     output_val = 0;
    //   } else if (input_val_centered > input_range_radius) {
    //     output_val = 255;
    //   } else {
    //     ...
    uint16x8_t mask_rightclamp_0 =
        vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
    uint16x8_t mask_rightclamp_1 =
        vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
    uint16x8_t mask_leftclamp_0 =
        vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
    uint16x8_t mask_leftclamp_1 =
        vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
    uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
                                             vshrn_n_u16(mask_rightclamp_1, 8));
    uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
                                            vshrn_n_u16(mask_leftclamp_1, 8));

    // This performs what is expressed in the scalar code as
    // const int32 input_val_rescaled =
    //     MultiplyByQuantizedMultiplierGreaterThanOne(
    //         input_val_centered, input_multiplier, input_left_shift);
    int32x4_t input_val_rescaled_0 =
        vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
                  vdupq_n_s32(input_left_shift));
    int32x4_t input_val_rescaled_1 =
        vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
                  vdupq_n_s32(input_left_shift));
    int32x4_t input_val_rescaled_2 =
        vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
                  vdupq_n_s32(input_left_shift));
    int32x4_t input_val_rescaled_3 =
        vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
                  vdupq_n_s32(input_left_shift));
    input_val_rescaled_0 =
        vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
    input_val_rescaled_1 =
        vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
    input_val_rescaled_2 =
        vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
    input_val_rescaled_3 =
        vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);

    // Invoke gemmlowp::tanh on FixedPoint wrapping int32x4_t
    using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
    using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
    const FixedPoint4 input_val_f4_0 =
        FixedPoint4::FromRaw(input_val_rescaled_0);
    const FixedPoint4 input_val_f4_1 =
        FixedPoint4::FromRaw(input_val_rescaled_1);
    const FixedPoint4 input_val_f4_2 =
        FixedPoint4::FromRaw(input_val_rescaled_2);
    const FixedPoint4 input_val_f4_3 =
        FixedPoint4::FromRaw(input_val_rescaled_3);
    const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0);
    const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1);
    const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2);
    const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3);

    // Divide by 2^24 as in the scalar code
    using gemmlowp::RoundingDivideByPOT;
    int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 24);
    int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 24);
    int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 24);
    int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 24);

    // Add the output zero point
    int32x4_t output_zero_point_s32 = vdupq_n_s32(output_zero_point);
    output_val_s32_0 = vaddq_s32(output_val_s32_0, output_zero_point_s32);
    output_val_s32_1 = vaddq_s32(output_val_s32_1, output_zero_point_s32);
    output_val_s32_2 = vaddq_s32(output_val_s32_2, output_zero_point_s32);
    output_val_s32_3 = vaddq_s32(output_val_s32_3, output_zero_point_s32);

    // Cast output values to uint8, saturating
    int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
                                              vqmovn_s32(output_val_s32_1));
    int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
                                              vqmovn_s32(output_val_s32_3));
    uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
                                           vqmovun_s16(output_val_s16_1));

    // Perform the bit-masking with the bit masks computed at the beginning,
    // see the comment there.
    output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
    output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);

    // Store back to memory
    vst1q_u8(output_data + c, output_val_u8);
  }
#endif
  // Leftover loop: handle one value at a time with scalar code.
  for (; c < size; ++c) {
    const uint8 input_val_u8 = input_data[c];
    const int32 input_val_centered =
        static_cast<int32>(input_val_u8) - input_zero_point;
    uint8 output_val;
    if (input_val_centered < -input_range_radius) {
      output_val = 0;
    } else if (input_val_centered > input_range_radius) {
      output_val = 255;
    } else {
      const int32 input_val_rescaled =
          MultiplyByQuantizedMultiplierGreaterThanOne(
              input_val_centered, input_multiplier, input_left_shift);
      using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
      using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
      const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
      const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
      using gemmlowp::RoundingDivideByPOT;
      int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24);
      output_val_s32 += output_zero_point;
      if (output_val_s32 == 256) {
        output_val_s32 = 255;
      }
      TFLITE_DCHECK_GE(output_val_s32, 0);
      TFLITE_DCHECK_LE(output_val_s32, 255);
      output_val = static_cast<uint8>(output_val_s32);
    }
    output_data[c] = output_val;
  }
}
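
// Why output_zero_point is hard-coded to 128 (my explanation, not an
// upstream comment): tanh lands in [-1, 1), which this kernel quantizes
// with scale 1/128, so real 0.0 must map to the uint8 value 128. The Q0.31
// tanh result is shifted down by 24 bits into [-128, 128] and then offset
// by 128 into the [0, 255] output range, with the same 256 -> 255
// saturation as Logistic.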

inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
                 int32 input_zero_point, int32 input_range_radius,
                 int32 input_multiplier, int input_left_shift,
                 uint8* output_data, const RuntimeShape& output_shape) {
  TanhParams params;
  params.input_zero_point = input_zero_point;
  params.input_range_radius = input_range_radius;
  params.input_multiplier = input_multiplier;
  params.input_left_shift = input_left_shift;
  Tanh(params, input_shape, input_data, output_shape, output_data);
}

inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
                 int32 input_zero_point, int32 input_range_radius,
                 int32 input_multiplier, int input_left_shift,
                 uint8* output_data, const Dims<4>& output_dims) {
  Tanh(input_data, DimsToShape(input_dims), input_zero_point,
       input_range_radius, input_multiplier, input_left_shift, output_data,
       DimsToShape(output_dims));
}

inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
                 int input_left_shift, int16* output_data,
                 const RuntimeShape& output_shape) {
  TanhParams params;
  params.input_left_shift = input_left_shift;
  Tanh(params, input_shape, input_data, output_shape, output_data);
}

inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
                 int input_left_shift, int16* output_data,
                 const Dims<4>& output_dims) {
  Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
       DimsToShape(output_dims));
}

template <typename T>
inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
                         int block_size, T* output_data,
                         const Dims<4>& output_dims) {
  tflite::DepthToSpaceParams op_params;
  op_params.block_size = block_size;

  DepthToSpace(op_params, DimsToShape(input_dims), input_data,
               DimsToShape(output_dims), output_data);
}

template <typename T>
inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
                         int block_size, T* output_data,
                         const Dims<4>& output_dims) {
  tflite::SpaceToDepthParams op_params;
  op_params.block_size = block_size;

  SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
               DimsToShape(output_dims), output_data);
}
4665 
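// Legacy Mul overloads. Each packs its activation bounds (and, for the
// int16-to-uint8 variant, the output offset) into ArithmeticParams before
// forwarding to the params-based Mul.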
inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
                const float* input2_data, const Dims<4>& input2_dims,
                float output_activation_min, float output_activation_max,
                float* output_data, const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;

  Mul(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}

template <FusedActivationFunctionType Ac>
void Mul(const float* input1_data, const Dims<4>& input1_dims,
         const float* input2_data, const Dims<4>& input2_dims,
         float* output_data, const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);

  Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
      output_activation_max, output_data, output_dims);
}

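// For illustration (hypothetical call site, not part of this header): a
// fused-Relu6 multiply resolves through GetActivationMinMax to the bounds
// [0.0f, 6.0f] before the clamped product is written out:
//
//   Mul<FusedActivationFunctionType::kRelu6>(in1, in1_dims, in2, in2_dims,
//                                            out, out_dims);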
inline void Mul(const int32* input1_data, const Dims<4>& input1_dims,
                const int32* input2_data, const Dims<4>& input2_dims,
                int32 output_activation_min, int32 output_activation_max,
                int32* output_data, const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;

  Mul(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}

template <FusedActivationFunctionType Ac>
void Mul(const int32* input1_data, const Dims<4>& input1_dims,
         const int32* input2_data, const Dims<4>& input2_dims,
         int32* output_data, const Dims<4>& output_dims) {
  TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
  tflite::ArithmeticParams op_params;
  // No parameters needed.

  MulNoActivation(op_params, DimsToShape(input1_dims), input1_data,
                  DimsToShape(input2_dims), input2_data,
                  DimsToShape(output_dims), output_data);
}

inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
                const int16* input2_data, const Dims<4>& input2_dims,
                int16* output_data, const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  // No parameters needed.

  Mul(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}

inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
                const int16* input2_data, const Dims<4>& input2_dims,
                int32 output_offset, int32 output_activation_min,
                int32 output_activation_max, uint8* output_data,
                const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  op_params.output_offset = output_offset;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;

  Mul(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}

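// Legacy broadcast Mul entry points; both forward to the slow, general
// BroadcastMul4DSlow.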
template <typename T>
void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
                  const T* input2_data, const Dims<4>& input2_dims,
                  T output_activation_min, T output_activation_max,
                  T* output_data, const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  SetActivationParams(output_activation_min, output_activation_max, &op_params);

  BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
                     DimsToShape(input2_dims), input2_data,
                     DimsToShape(output_dims), output_data);
}

// For compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
inline void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims,
                         const float* input2_data, const Dims<4>& input2_dims,
                         float* output_data, const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  float float_activation_min;
  float float_activation_max;
  GetActivationMinMax(Ac, &float_activation_min, &float_activation_max);
  SetActivationParams(float_activation_min, float_activation_max, &op_params);

  BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
                     DimsToShape(input2_dims), input2_data,
                     DimsToShape(output_dims), output_data);
}

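// Legacy LRN wrapper: packs range, bias, alpha, and beta into
// LocalResponseNormalizationParams.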
inline void LocalResponseNormalization(const float* input_data,
                                       const Dims<4>& input_dims, int range,
                                       float bias, float alpha, float beta,
                                       float* output_data,
                                       const Dims<4>& output_dims) {
  tflite::LocalResponseNormalizationParams op_params;
  op_params.range = range;
  op_params.bias = bias;
  op_params.alpha = alpha;
  op_params.beta = beta;

  LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
                             DimsToShape(output_dims), output_data);
}

template <typename SrcT, typename DstT>
void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
          const Dims<4>& output_dims) {
  Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
       output_data);
}

inline void Floor(const float* input_data, const Dims<4>& input_dims,
                  float* output_data, const Dims<4>& output_dims) {
  Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
        output_data);
}

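// Legacy ResizeBilinear wrappers. half_pixel_centers is hard-coded to false,
// matching the old behavior these wrappers preserve; the overloads without an
// align_corners argument default it to false as well.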
inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
                           const Dims<4>& output_size_dims, float* output_data,
                           const Dims<4>& output_dims, bool align_corners) {
  tflite::ResizeBilinearParams op_params;
  op_params.align_corners = align_corners;
  op_params.half_pixel_centers = false;
  ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
                 DimsToShape(output_size_dims), output_size_data,
                 DimsToShape(output_dims), output_data);
}

inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
                           const Dims<4>& output_size_dims, uint8* output_data,
                           const Dims<4>& output_dims, bool align_corners) {
  tflite::ResizeBilinearParams op_params;
  op_params.align_corners = align_corners;
  op_params.half_pixel_centers = false;
  ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
                 DimsToShape(output_size_dims), output_size_data,
                 DimsToShape(output_dims), output_data);
}

// legacy, for compatibility with old checked-in code
inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
                           const Dims<4>& output_size_dims, float* output_data,
                           const Dims<4>& output_dims) {
  ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
                 output_data, output_dims, /*align_corners=*/false);
}

// legacy, for compatibility with old checked-in code
inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
                           const Dims<4>& output_size_dims, uint8* output_data,
                           const Dims<4>& output_dims) {
  ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
                 output_data, output_dims, /*align_corners=*/false);
}

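// Legacy BatchToSpaceND wrapper: a pure shape conversion, with no params
// struct involved.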
template <typename T>
inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
                           const int32* block_shape_data,
                           const Dims<4>& block_shape_dims,
                           const int32* crops_data, const Dims<4>& crops_dims,
                           T* output_data, const Dims<4>& output_dims) {
  BatchToSpaceND(DimsToShape(input_dims), input_data,
                 DimsToShape(block_shape_dims), block_shape_data,
                 DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
                 output_data);
}

// Legacy signature; this function covered both Pad and PadV2.
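// Dims<4> stores dimensions in the reverse order of RuntimeShape, so the
// per-dimension paddings are flipped with [3 - i] below.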
template <typename T>
inline void PadV2(const T* input_data, const Dims<4>& input_dims,
                  const std::vector<int>& left_paddings,
                  const std::vector<int>& right_paddings, T* output_data,
                  const Dims<4>& output_dims, const T pad_value) {
  TFLITE_DCHECK_EQ(left_paddings.size(), 4);
  TFLITE_DCHECK_EQ(right_paddings.size(), 4);
  tflite::PadParams op_params;
  op_params.left_padding_count = 4;
  op_params.right_padding_count = 4;
  for (int i = 0; i < 4; ++i) {
    op_params.left_padding[i] = left_paddings[3 - i];
    op_params.right_padding[i] = right_paddings[3 - i];
  }
  const T pad_value_copy = pad_value;

  Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
      DimsToShape(output_dims), output_data);
}

// Old Pad that calls legacy PadV2.
template <typename T>
inline void Pad(const T* input_data, const Dims<4>& input_dims,
                const std::vector<int>& left_paddings,
                const std::vector<int>& right_paddings, T* output_data,
                const Dims<4>& output_dims, const int32_t pad_value) {
  const T converted_pad_value = static_cast<T>(pad_value);
  PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
           output_dims, converted_pad_value);
}

// Old Pad that only padded with 0.
template <typename T>
inline void Pad(const T* input_data, const Dims<4>& input_dims,
                const std::vector<int>& left_paddings,
                const std::vector<int>& right_paddings, T* output_data,
                const Dims<4>& output_dims) {
  const T pad_value = static_cast<T>(0);
  PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
           output_dims, pad_value);
}

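// Legacy Slice: the begin/size vectors arrive in reverse (Dims<4>) order and
// are flipped into RuntimeShape order here.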
template <typename T>
inline void Slice(const T* input_data, const Dims<4>& input_dims,
                  const std::vector<int>& begin, const std::vector<int>& size,
                  T* output_data, const Dims<4>& output_dims) {
  tflite::SliceParams op_params;
  op_params.begin_count = 4;
  op_params.size_count = 4;
  for (int i = 0; i < 4; ++i) {
    op_params.begin[i] = begin[3 - i];
    op_params.size[i] = size[3 - i];
  }

  Slice(op_params, DimsToShape(input_dims), input_data,
        DimsToShape(output_dims), output_data);
}

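// Legacy TensorFlow-prefixed aliases; they forward to Minimum / Maximum.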
template <typename T>
void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
                       const T* input2_data, T* output_data,
                       const Dims<4>& output_dims) {
  Minimum(DimsToShape(input1_dims), input1_data, input2_data,
          DimsToShape(output_dims), output_data);
}

template <typename T>
void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
                       const T* input2_data, T* output_data,
                       const Dims<4>& output_dims) {
  Maximum(DimsToShape(input1_dims), input1_data, input2_data,
          DimsToShape(output_dims), output_data);
}

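// Legacy Dequantize wrapper: the underlying op computes
// output = scale * (input - zero_point).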
inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
                       int32 zero_point, double scale, float* output_data,
                       const Dims<4>& output_dims) {
  tflite::DequantizationParams op_params;
  op_params.zero_point = zero_point;
  op_params.scale = scale;

  Dequantize(op_params, DimsToShape(input_dims), input_data,
             DimsToShape(output_dims), output_data);
}

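// Legacy Transpose: converting a permutation between reversed (Dims<4>) and
// standard (RuntimeShape) axis order requires both reindexing the array and
// remapping each axis value, hence perm[i] = 3 - permuted_axes[3 - i].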
template <typename T>
void Transpose(const T* input, const Dims<4>& input_dims, T* output,
               const Dims<4>& output_dims, const int* permuted_axes) {
  TransposeParams params;
  params.perm_count = 4;
  for (int i = 0; i < 4; ++i) {
    params.perm[i] = 3 - permuted_axes[3 - i];
  }
  Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
            output);
}

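// Legacy StridedSlice: start/stop/strides arrive in reverse (Dims<4>) order;
// BuildStridedSliceParams packs them and StridedSliceReverseIndices flips
// them into RuntimeShape order.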
template <typename T>
inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
                         int begin_mask, int end_mask, int shrink_axis_mask,
                         const std::vector<int>& start_indices,
                         const std::vector<int>& stop_indices,
                         const std::vector<int>& strides, T* output_data,
                         const Dims<4>& output_dims) {
  TFLITE_DCHECK_EQ(start_indices.size(), 4);
  auto op_params = strided_slice::BuildStridedSliceParams(
      begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
      strides);
  reference_ops::StridedSliceReverseIndices(&op_params);

  StridedSlice(op_params, DimsToShape(input_dims), input_data,
               DimsToShape(output_dims), output_data);
}

}  // namespace optimized_ops
}  // namespace tflite
#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_